2024-07-09 14:54:34 -04:00
|
|
|
from functools import partial
|
|
|
|
|
2022-10-22 07:17:29 -04:00
|
|
|
import jax
|
|
|
|
import jax.lax as lax
|
2024-07-09 14:54:34 -04:00
|
|
|
import jax.numpy as jnp
|
|
|
|
import jax_cosmo as jc
|
2022-10-22 07:17:29 -04:00
|
|
|
from jax.experimental.maps import xmap
|
2024-07-09 14:54:34 -04:00
|
|
|
from jax.experimental.pjit import PartitionSpec, pjit
|
2022-10-22 07:17:29 -04:00
|
|
|
|
|
|
|
import jaxpm.painting as paint
|
|
|
|
|
|
|
|
# TODO: add a way to configure axis resources from command line

# Logical xmap axis name -> mesh resource (device axis) it is partitioned over.
axis_resources = dict(x='nx', y='ny')

# Number of devices along each mesh resource. Per-device extents of the
# distributed axes are obtained by integer-dividing global sizes by these.
mesh_size = dict(nx=2, ny=2)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=({0: 'x', 2: 'y'}, {0: 'x', 2: 'y'}, {0: 'x', 2: 'y'}),
         out_axes=({0: 'x', 2: 'y'}),
         axis_resources=axis_resources)
def stack3d(a, b, c):
    """Stack three distributed arrays along a new trailing axis of size 3.

    Positional axes 0 and 2 of each input (and of the output) are mapped to
    the 'x' and 'y' mesh axes, so the stacking is done on per-device blocks.
    """
    components = (a, b, c)
    return jnp.stack(components, axis=-1)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=({0: 'x', 2: 'y'}, [...]),
         out_axes=({0: 'x', 2: 'y'}),
         axis_resources=axis_resources)
def scalar_multiply(a, factor):
    """Multiply a distributed array by a factor that carries no named axes.

    Args:
      a: array whose positional axes 0 and 2 are mapped to the 'x' and 'y'
        mesh axes.
      factor: value with no named axes (in_axes ``[...]``), so each device
        uses the same factor.
    """
    scaled = a * factor
    return scaled
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=({0: 'x', 2: 'y'}, {0: 'x', 2: 'y'}),
         out_axes=({0: 'x', 2: 'y'}),
         axis_resources=axis_resources)
def add(a, b):
    """Element-wise sum of two arrays distributed with the same layout.

    Positional axes 0 and 2 of both inputs and the output are mapped to the
    'x' and 'y' mesh axes.
    """
    total = a + b
    return total
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=['x', 'y', ...],
         out_axes=['x', 'y', ...],
         axis_resources=axis_resources)
def fft3d(mesh):
    """Perform a distributed 3D complex Fourier transform.

    A 1D FFT is taken along the last local axis; then, for the 'x' and then
    the 'y' mesh axis, an ``all_to_all`` (split and concat on local axis 0)
    redistributes the data and another 1D FFT is taken along the last axis.

    Args:
      mesh: a real 3D tensor of shape [Nx, Ny, Nz]

    Returns:
      3D FFT of the input. Note that the dimensions of the output are
      transposed (the original author notes the order as [z, x, y]).
    """
    spectrum = jnp.fft.fft(mesh)
    for axis_name in ('x', 'y'):
        spectrum = lax.all_to_all(spectrum, axis_name, 0, 0)
        spectrum = jnp.fft.fft(spectrum)
    return spectrum
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=['x', 'y', ...],
         out_axes=['x', 'y', ...],
         axis_resources=axis_resources)
def ifft3d(mesh):
    """Inverse of ``fft3d``: distributed 3D inverse FFT returning the real part.

    Mirrors ``fft3d`` in reverse order: 1D inverse FFT along the last local
    axis, then for the 'y' and then the 'x' mesh axis an ``all_to_all``
    (split and concat on local axis 0) followed by another 1D inverse FFT.
    Only the real part of the result is returned.
    """
    field = jnp.fft.ifft(mesh)
    for axis_name in ('y', 'x'):
        field = lax.all_to_all(field, axis_name, 0, 0)
        field = jnp.fft.ifft(field)
    return field.real
|
|
|
|
|
2024-07-09 14:54:34 -04:00
|
|
|
|
2022-10-22 07:17:29 -04:00
|
|
|
def normal(key, shape=None):
    """Generate a distributed array of standard-normal random numbers.

    Args:
      key: array of random keys with the same layout as the computational
        mesh; its positional axes 0 and 1 are mapped to the 'x' and 'y'
        mesh axes.
      shape: logical (global) shape of the array to sample, as a list or
        tuple with at least two entries. The first two entries are divided
        by ``mesh_size['nx']`` and ``mesh_size['ny']`` to get the
        per-device block shape.

    Returns:
      Samples whose positional axes 0 and 2 are the 'x' and 'y' mesh axes,
      with per-device blocks of shape
      ``[shape[0] // nx, shape[1] // ny, *shape[2:]]``.
    """
    # Use None instead of a shared mutable [] default, and accept tuples:
    # ``list + tuple`` would otherwise raise TypeError below.
    shape = [] if shape is None else list(shape)
    local_shape = ([shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] +
                   shape[2:])

    @partial(xmap,
             in_axes=['x', 'y', ...],
             out_axes={0: 'x', 2: 'y'},
             axis_resources=axis_resources)
    def fn(key):
        # Each device draws its own block using its own key.
        return jax.random.normal(key, shape=local_shape)

    return fn(key)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]),
         out_axes=['x', 'y', ...],
         axis_resources=axis_resources)
@jax.jit
def scale_by_power_spectrum(kfield, kvec, k, pk):
    """Multiply a Fourier-space field by a tabulated spectrum evaluated at |k|.

    Args:
      kfield: field whose positional axes 0 and 1 are mapped to 'x' and 'y'.
      kvec: sequence of three wavevector component arrays (kx, ky, kz); the
        first is distributed over 'x', the second over 'y', the third has no
        named axis.
      k: sample points of the tabulated spectrum (no named axes).
      pk: tabulated spectrum values at ``k`` (no named axes).

    Returns:
      ``kfield`` multiplied by ``pk`` linearly interpolated at ``|k|``.
    """
    kx, ky, kz = kvec
    k_norm = jnp.sqrt(kx**2 + ky**2 + kz**2)
    pk_at_k = jc.scipy.interpolate.interp(k_norm, k, pk)
    return kfield * pk_at_k
|
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]),
         out_axes=(['x', 'y', 'z']),
         axis_resources=axis_resources)
def gradient_laplace_kernel(kfield, kvec):
    """Apply a 1/|k|^2 kernel and a finite-difference gradient in Fourier space.

    Args:
      kfield: Fourier-space field.
      kvec: sequence of three wavevector component arrays (kx, ky, kz).

    Returns:
      A 3-tuple ``kfield / |k|^2 * 1j * (8 sin(k_j) - sin(2 k_j)) / 6`` for
      ``k_j`` taken, in order, from ky, kz and kx. The |k| = 0 mode uses a
      kernel value of 1 instead of 1/0.
    """
    kx, ky, kz = kvec
    k_squared = (kx**2 + ky**2 + kz**2)
    inv_k_squared = jnp.where(k_squared == 0, 1., 1. / k_squared)
    prefactor = kfield * inv_k_squared * 1j * 1 / 6.0

    def fd_symbol(k_component):
        # Fourier symbol (times 6/i) of the 4th-order central difference
        # (-f[n+2] + 8 f[n+1] - 8 f[n-1] + f[n-2]) / 12.
        return 8 * jnp.sin(k_component) - jnp.sin(2 * k_component)

    return tuple(prefactor * fd_symbol(k_j) for k_j in (ky, kz, kx))
|
2022-10-22 07:17:29 -04:00
|
|
|
|
|
|
|
|
|
|
|
@partial(xmap,
         in_axes=([...]),
         out_axes={0: 'x', 2: 'y'},
         axis_sizes={'x': mesh_size['nx'], 'y': mesh_size['ny']},
         axis_resources=axis_resources)
def meshgrid(x, y, z):
    """Generate a coordinate grid of appropriate size for the computational mesh.

    The 1D coordinate arrays carry no named axes; the output gets the 'x' and
    'y' mesh axes inserted at positions 0 and 2, with sizes taken from
    ``mesh_size``. The three coordinate grids are stacked on a trailing axis.

    NOTE(review): ``jnp.meshgrid`` uses its default 'xy' indexing here;
    confirm 'ij' was not intended.
    """
    grids = jnp.meshgrid(x, y, z)
    return jnp.stack(grids, axis=-1)
|
2022-10-22 07:17:29 -04:00
|
|
|
|
|
|
|
|
|
|
|
def cic_paint(pos, mesh_shape, halo_size=0):
    """Paint particles onto a distributed mesh with cloud-in-cell interpolation.

    Each device paints its local particles onto its block of the mesh, padded
    by ``halo_size`` cells on both sides of the two distributed axes. The
    2 * halo_size border slabs are then summed with those of the periodic
    neighbours along 'x' and 'y', and the padding is removed.

    Args:
      pos: particle positions; positional axes 0 and 2 are mapped to the 'x'
        and 'y' mesh axes, and each per-device block is reshaped to (-1, 3).
      mesh_shape: global mesh shape ``[Nx, Ny, *rest]`` (list or tuple).
      halo_size: width in cells of the padding on each side of the
        distributed axes. With 0, no halo exchange is performed.
        NOTE(review): in that case any mass deposited past a block edge is
        handled only by ``paint.cic_paint`` itself; confirm this is intended.

    Returns:
      The painted mesh; positional axes 0 and 2 are mapped to 'x' and 'y',
      with per-device blocks of shape ``[Nx // nx, Ny // ny, *rest]``.
    """
    # Accept tuples as well as lists (``list + tuple`` raises TypeError).
    mesh_shape = list(mesh_shape)

    def _add_neighbour_halos(mesh, axis, axis_name, n_devices):
        """Sum the 2*halo_size border slabs along `axis` with the neighbours'."""
        width = 2 * halo_size
        lead = (slice(None),) * axis
        head = lead + (slice(None, width),)
        tail = lead + (slice(-width, None),)
        # lax.pshuffle: the output on device i is the input from device
        # perm[i]. Use periodic neighbours; a reversal permutation only
        # coincides with this for at most 2 devices along the axis.
        from_left = [(i - 1) % n_devices for i in range(n_devices)]
        from_right = [(i + 1) % n_devices for i in range(n_devices)]
        left = lax.pshuffle(mesh[tail], perm=from_left, axis_name=axis_name)
        right = lax.pshuffle(mesh[head], perm=from_right, axis_name=axis_name)
        mesh = mesh.at[head].add(left)
        return mesh.at[tail].add(right)

    @partial(xmap,
             in_axes=({0: 'x', 2: 'y'}),
             out_axes=({0: 'x', 2: 'y'}),
             axis_resources=axis_resources)
    def fn(pos):
        mesh = jnp.zeros([
            mesh_shape[0] // mesh_size['nx'] + 2 * halo_size,
            mesh_shape[1] // mesh_size['ny'] + 2 * halo_size,
        ] + mesh_shape[2:])

        # Paint particles, shifted by the halo width along the distributed axes
        mesh = paint.cic_paint(
            mesh,
            pos.reshape(-1, 3) +
            jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))

        # With halo_size == 0, slicing with -0 would select the whole block
        # (and crop to an empty one), so only exchange when there is a halo.
        if halo_size > 0:
            mesh = _add_neighbour_halos(mesh, 0, 'x', mesh_size['nx'])
            mesh = _add_neighbour_halos(mesh, 1, 'y', mesh_size['ny'])
            # Remove the halo padding
            mesh = mesh[halo_size:-halo_size, halo_size:-halo_size]
        return mesh

    return fn(pos)
|
|
|
|
|
2024-07-09 14:54:34 -04:00
|
|
|
|
2022-10-22 07:17:29 -04:00
|
|
|
def cic_read(mesh, pos, halo_size=0):
    """Read a distributed mesh at particle positions with cloud-in-cell interpolation.

    Each device pads its block of the mesh with ``halo_size`` slabs taken
    from its periodic neighbours along 'x' and then 'y', and then
    interpolates the padded block at its local particle positions.

    Args:
      mesh: distributed mesh; positional axes 0 and 2 are mapped to the 'x'
        and 'y' mesh axes.
      pos: particle positions with the same axis mapping; each per-device
        block is reshaped to (-1, 3).
      halo_size: width in cells of the padding taken from each neighbour.
        With 0, no halo exchange is performed.

    Returns:
      Interpolated values with the per-device shape ``pos.shape[:-1]``.
    """

    def _pad_with_neighbour_halos(mesh, axis, axis_name, n_devices):
        """Pad `mesh` along `axis` with halo_size slabs from the neighbours."""
        lead = (slice(None),) * axis
        head = lead + (slice(None, halo_size),)
        tail = lead + (slice(-halo_size, None),)
        # lax.pshuffle: the output on device i is the input from device
        # perm[i]. Use periodic neighbours; a reversal permutation only
        # coincides with this for at most 2 devices along the axis.
        from_left = [(i - 1) % n_devices for i in range(n_devices)]
        from_right = [(i + 1) % n_devices for i in range(n_devices)]
        left = lax.pshuffle(mesh[tail], perm=from_left, axis_name=axis_name)
        right = lax.pshuffle(mesh[head], perm=from_right, axis_name=axis_name)
        return jnp.concatenate([left, mesh, right], axis=axis)

    @partial(xmap,
             in_axes=({0: 'x', 2: 'y'}, {0: 'x', 2: 'y'}),
             out_axes=({0: 'x', 2: 'y'}),
             axis_resources=axis_resources)
    def fn(mesh, pos):
        # With halo_size == 0, mesh[-0:] would select the whole block and
        # triple it on concatenation, so only exchange when there is a halo.
        if halo_size > 0:
            mesh = _pad_with_neighbour_halos(mesh, 0, 'x', mesh_size['nx'])
            mesh = _pad_with_neighbour_halos(mesh, 1, 'y', mesh_size['ny'])

        # Reading field at particles positions
        res = paint.cic_read(
            mesh,
            pos.reshape(-1, 3) +
            jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))

        return res.reshape(pos.shape[:-1])

    return fn(mesh, pos)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(pjit,
         in_axis_resources=PartitionSpec('nx', 'ny'),
         out_axis_resources=PartitionSpec('nx', None, 'ny', None))
def reshape_dense_to_split(x):
    """Redistribute data from the [x, y, z] convention to [Nx, x, Ny, y, z].

    Changes the logical shape of the array, but no shuffling of the data
    should be necessary.

    Args:
      x: array of shape ``[X, Y, *rest]``, with X divisible by
        ``mesh_size['nx']`` and Y divisible by ``mesh_size['ny']``.

    Returns:
      Array of shape ``[nx, X // nx, ny, Y // ny, *rest]``.
    """
    shape = list(x.shape)
    nx, ny = mesh_size['nx'], mesh_size['ny']
    # The 'ny' split applies to axis 1 (y). Dividing shape[2] instead only
    # worked when Y == Z.
    return x.reshape([nx, shape[0] // nx, ny, shape[1] // ny] + shape[2:])
|
2022-10-22 07:17:29 -04:00
|
|
|
|
|
|
|
|
|
|
|
@partial(pjit,
         in_axis_resources=PartitionSpec('nx', None, 'ny', None),
         out_axis_resources=PartitionSpec('nx', 'ny'))
def reshape_split_to_dense(x):
    """Redistribute data from the [Nx, x, Ny, y, z] convention to [x, y, z].

    Only the logical shape changes; no shuffling of the data should be
    necessary. Axes 0 and 1 are merged, as are axes 2 and 3; any remaining
    axes are kept as-is.
    """
    dims = list(x.shape)
    merged = [dims[0] * dims[1], dims[2] * dims[3]]
    return x.reshape(merged + dims[4:])
|