forked from guilhem_lavaux/JaxPM
adding distributed ops
This commit is contained in:
parent
1cb2739366
commit
137f4e5099
2 changed files with 317 additions and 0 deletions
223
jaxpm/distributed_ops.py
Normal file
223
jaxpm/distributed_ops.py
Normal file
|
@ -0,0 +1,223 @@
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax.lax as lax
|
||||||
|
from functools import partial
|
||||||
|
from jax.experimental.maps import xmap
|
||||||
|
from jax.experimental.pjit import pjit, PartitionSpec
|
||||||
|
|
||||||
|
import jax_cosmo as jc
|
||||||
|
import jaxpm as jpm
|
||||||
|
|
||||||
|
# TODO: add a way to configure axis resources from command line
|
||||||
|
axis_resources = {'x': 'nx', 'y': 'ny'}
|
||||||
|
mesh_size = {'nx': 2, 'ny': 2}
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'},
|
||||||
|
{0: 'x', 2: 'y'},
|
||||||
|
{0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def stack3d(a, b, c):
|
||||||
|
return jnp.stack([a, b, c], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def scalar_multiply(a, factor):
|
||||||
|
return a * factor
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'},
|
||||||
|
{0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def add(a, b):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y'],
|
||||||
|
out_axes=['x', 'y'],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def fft3d(mesh):
|
||||||
|
""" Performs a 3D complex Fourier transform
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mesh: a real 3D tensor of shape [Nx, Ny, Nz]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
3D FFT of the input, note that the dimensions of the output
|
||||||
|
are tranposed.
|
||||||
|
"""
|
||||||
|
mesh = jnp.fft.fft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||||
|
mesh = jnp.fft.fft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0)
|
||||||
|
return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y]
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y'],
|
||||||
|
out_axes=['x', 'y'],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def ifft3d(mesh):
|
||||||
|
mesh = jnp.fft.ifft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'y', 0, 0)
|
||||||
|
mesh = jnp.fft.ifft(mesh)
|
||||||
|
mesh = lax.all_to_all(mesh, 'x', 0, 0)
|
||||||
|
return jnp.fft.ifft(mesh).real
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=['x', 'y'],
|
||||||
|
out_axes={0: 'x', 2: 'y'},
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def normal(key, shape):
|
||||||
|
""" Generate a distributed random normal distributions
|
||||||
|
Args:
|
||||||
|
key: array of random keys with same layout as computational mesh
|
||||||
|
shape: logical shape of array to sample
|
||||||
|
"""
|
||||||
|
return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
|
||||||
|
shape[1]//mesh_size['ny']]+shape[2:])
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=(['x', 'y', ...],
|
||||||
|
[['x'], ['y'], ...]),
|
||||||
|
out_axes=['x', 'y', ...],
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
@jax.jit
|
||||||
|
def scale_by_power_spectrum(kfield, kvec, k, pk):
|
||||||
|
kx, ky, kz = kvec
|
||||||
|
kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
|
||||||
|
return kfield * jc.scipy.interpolate.interp(kk, k, pk)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=(['x', 'y', 'z'],
|
||||||
|
[['x'], ['y'], ['z']]),
|
||||||
|
out_axes=(['x', 'y', 'z']),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def gradient_laplace_kernel(kfield, kvec):
|
||||||
|
kx, ky, kz = kvec
|
||||||
|
kk = (kx**2 + ky**2 + kz**2)
|
||||||
|
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||||
|
return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||||
|
kfield * kernel * 1j * 1 / 6.0 *
|
||||||
|
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||||
|
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=([...]),
|
||||||
|
out_axes={0: 'x', 2: 'y'},
|
||||||
|
axis_sizes={'x': mesh_size['nx'],
|
||||||
|
'y': mesh_size['ny']},
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def meshgrid(x, y, z):
|
||||||
|
""" Generates a mesh grid of appropriate size for the
|
||||||
|
computational mesh we have.
|
||||||
|
"""
|
||||||
|
return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def cic_paint(pos, mesh_shape, halo_size=0):
|
||||||
|
|
||||||
|
mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size,
|
||||||
|
mesh_shape[1]//mesh_size['ny']+2*halo_size]
|
||||||
|
+ mesh_shape[2:])
|
||||||
|
|
||||||
|
# Paint particles
|
||||||
|
mesh = jpm.cic_paint(mesh, pos.reshape(-1, 3) +
|
||||||
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
|
# Perform halo exchange
|
||||||
|
# Halo exchange along x
|
||||||
|
left = lax.pshuffle(mesh[-halo_size:],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
right = lax.pshuffle(mesh[:halo_size],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
mesh = mesh.at[:halo_size].add(left)
|
||||||
|
mesh = mesh.at[-halo_size:].add(right)
|
||||||
|
|
||||||
|
# Halo exchange along y
|
||||||
|
left = lax.pshuffle(mesh[:, -halo_size:],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
right = lax.pshuffle(mesh[:, :halo_size],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
mesh = mesh.at[:, :halo_size].add(left)
|
||||||
|
mesh = mesh.at[:, -halo_size:].add(right)
|
||||||
|
|
||||||
|
# removing halo and returning mesh
|
||||||
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
|
||||||
|
@partial(xmap,
|
||||||
|
in_axes=({0: 'x', 2: 'y'},
|
||||||
|
{0: 'x', 2: 'y'}),
|
||||||
|
out_axes=({0: 'x', 2: 'y'}),
|
||||||
|
axis_resources=axis_resources)
|
||||||
|
def cic_read(mesh, pos, halo_size):
|
||||||
|
|
||||||
|
# Halo exchange to grab neighboring borders
|
||||||
|
# Exchange along x
|
||||||
|
left = lax.pshuffle(mesh[-halo_size:],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
right = lax.pshuffle(mesh[:halo_size],
|
||||||
|
perm=range(mesh_size['nx'])[::-1],
|
||||||
|
axis_name='x')
|
||||||
|
mesh = jnp.concatenate([left, mesh, right], axis=0)
|
||||||
|
# Exchange along y
|
||||||
|
left = lax.pshuffle(mesh[:, -halo_size:],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
right = lax.pshuffle(mesh[:, :halo_size],
|
||||||
|
perm=range(mesh_size['ny'])[::-1],
|
||||||
|
axis_name='y')
|
||||||
|
mesh = jnp.concatenate([left, mesh, right], axis=1)
|
||||||
|
|
||||||
|
# Reading field at particles positions
|
||||||
|
res = jpm.painting.cic_read(mesh, pos.reshape(-1, 3) +
|
||||||
|
jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
@partial(pjit,
|
||||||
|
in_axis_resources=PartitionSpec('nx', 'ny'),
|
||||||
|
out_axis_resources=PartitionSpec('nx', None, 'ny', None))
|
||||||
|
def reshape_dense_to_split(x):
|
||||||
|
""" Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
|
||||||
|
Changes the logical shape of the array, but no shuffling of the
|
||||||
|
data should be necessary
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
|
||||||
|
mesh_size['ny'], shape[2]//mesh_size['ny']] + shape[2:])
|
||||||
|
|
||||||
|
|
||||||
|
@partial(pjit,
|
||||||
|
in_axis_resources=PartitionSpec('nx', None, 'ny', None),
|
||||||
|
out_axis_resources=PartitionSpec('nx', 'ny'))
|
||||||
|
def reshape_split_to_dense(x):
|
||||||
|
""" Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
|
||||||
|
Changes the logical shape of the array, but no shuffling of the
|
||||||
|
data should be necessary
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
|
94
jaxpm/distributed_pm.py
Normal file
94
jaxpm/distributed_pm.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
import jax
|
||||||
|
from jax.lax import linear_solve_p
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax.experimental.maps import xmap
|
||||||
|
from functools import partial
|
||||||
|
import jax_cosmo as jc
|
||||||
|
|
||||||
|
from jaxpm.kernels import fftk
|
||||||
|
import jaxpm.distributed_ops as dops
|
||||||
|
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||||
|
|
||||||
|
|
||||||
|
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
|
||||||
|
"""
|
||||||
|
Computes gravitational forces on particles using a PM scheme
|
||||||
|
"""
|
||||||
|
if mesh_shape is None:
|
||||||
|
mesh_shape = delta_k.shape
|
||||||
|
kvec = [k.squeeze() for k in fftk(mesh_shape)]
|
||||||
|
|
||||||
|
if delta_k is None:
|
||||||
|
delta = dops.cic_paint(positions, mesh_shape, halo_size)
|
||||||
|
delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
|
||||||
|
|
||||||
|
forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
|
||||||
|
|
||||||
|
# Recovers forces at particle positions
|
||||||
|
forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
|
||||||
|
positions, halo_size) for f in forces_k]
|
||||||
|
|
||||||
|
return dops.stack3d(*forces)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
|
||||||
|
"""
|
||||||
|
Generate initial conditions.
|
||||||
|
Seed should have the dimension of the computational mesh
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Sample normal field
|
||||||
|
field = dops.normal(seed, mesh_shape)
|
||||||
|
|
||||||
|
# Go to Fourier space
|
||||||
|
field = dops.fft3d(dops.reshape_split_to_dense(field))
|
||||||
|
|
||||||
|
# Rescaling k to physical units
|
||||||
|
kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
|
||||||
|
for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
|
||||||
|
k = jnp.logspace(-4, 2, 256)
|
||||||
|
pk = jc.power.linear_matter_power(cosmo, k)
|
||||||
|
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||||
|
) / (box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
|
field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
|
||||||
|
|
||||||
|
if return_Fourier:
|
||||||
|
return field
|
||||||
|
else:
|
||||||
|
return dops.reshape_dense_to_split(dops.ifft3d(field))
|
||||||
|
|
||||||
|
|
||||||
|
def lpt(cosmo, initial_conditions, positions, a):
|
||||||
|
"""
|
||||||
|
Computes first order LPT displacement
|
||||||
|
"""
|
||||||
|
initial_force = pm_forces(positions, delta_k=initial_conditions)
|
||||||
|
a = jnp.atleast_1d(a)
|
||||||
|
dx = dops.scalar_multiply(initial_force * growth_factor(cosmo, a))
|
||||||
|
p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
|
||||||
|
jnp.sqrt(jc.background.Esqr(cosmo, a)))
|
||||||
|
return dx, p
|
||||||
|
|
||||||
|
|
||||||
|
def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
|
def nbody_ode(state, a, cosmo):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
pos, vel = state
|
||||||
|
|
||||||
|
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = dops.scalar_multiply(
|
||||||
|
vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = dops.scalar_multiply(
|
||||||
|
forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
|
||||||
|
|
||||||
|
return dpos, dvel
|
||||||
|
|
||||||
|
return nbody_ode
|
Loading…
Add table
Reference in a new issue