2024-12-20 11:44:02 +01:00
|
|
|
from functools import partial
|
|
|
|
|
2022-02-13 21:36:03 +01:00
|
|
|
import jax
|
|
|
|
import jax.lax as lax
|
2024-07-09 14:54:34 -04:00
|
|
|
import jax.numpy as jnp
|
2024-12-20 11:44:02 +01:00
|
|
|
from jax.sharding import NamedSharding
|
|
|
|
from jax.sharding import PartitionSpec as P
|
2024-07-09 14:54:34 -04:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
|
|
|
|
ifft3d, slice_pad, slice_unpad)
|
2024-07-09 14:54:34 -04:00
|
|
|
from jaxpm.kernels import cic_compensation, fftk
|
2024-12-20 11:44:02 +01:00
|
|
|
from jaxpm.painting_utils import gather, scatter
|
2022-02-13 21:36:03 +01:00
|
|
|
|
2022-03-26 00:06:34 +01:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
def _cic_paint_impl(grid_mesh, positions, weight=None):
|
2024-07-09 14:54:34 -04:00
|
|
|
""" Paints positions onto mesh
|
2024-12-20 11:44:02 +01:00
|
|
|
mesh: [nx, ny, nz]
|
|
|
|
displacement field: [nx, ny, nz, 3]
|
|
|
|
"""
|
|
|
|
|
|
|
|
positions = positions.reshape([-1, 3])
|
2024-07-09 14:54:34 -04:00
|
|
|
positions = jnp.expand_dims(positions, 1)
|
|
|
|
floor = jnp.floor(positions)
|
|
|
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
|
|
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
|
|
|
|
|
|
|
neighboor_coords = floor + connection
|
|
|
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
|
|
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
|
|
|
if weight is not None:
|
2024-12-20 11:44:02 +01:00
|
|
|
if jnp.isscalar(weight):
|
|
|
|
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
|
|
|
else:
|
|
|
|
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
|
|
|
kernel)
|
2024-07-09 14:54:34 -04:00
|
|
|
|
|
|
|
neighboor_coords = jnp.mod(
|
|
|
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
2024-12-20 11:44:02 +01:00
|
|
|
jnp.array(grid_mesh.shape))
|
2024-07-09 14:54:34 -04:00
|
|
|
|
|
|
|
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
|
|
|
inserted_window_dims=(0, 1, 2),
|
|
|
|
scatter_dims_to_operand_dims=(0, 1,
|
|
|
|
2))
|
2024-12-20 11:44:02 +01:00
|
|
|
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
|
|
|
|
kernel.reshape([-1, 8]), dnums)
|
2024-07-09 14:54:34 -04:00
|
|
|
return mesh
|
|
|
|
|
2022-02-13 21:36:03 +01:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
    """Paint particles onto ``grid_mesh``, optionally sharded across devices.

    Args:
        grid_mesh: mesh to paint onto.
        positions: particle positions; reshaped to ``(*grid_mesh.shape, 3)``.
        weight: optional weight forwarded to ``_cic_paint_impl``.
        halo_size: halo specification handed to ``get_halo_size`` (static
            argument of the jitted function).
        sharding: optional ``NamedSharding``. Any other value runs with no
            device mesh and an empty partition spec (static argument).

    Returns:
        The painted mesh after halo exchange and ``slice_unpad``.
    """
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding)
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding)

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    # Mesh and positions share the sharding spec; the weight uses P().
    paint = autoshmap(_cic_paint_impl,
                      gpu_mesh=gpu_mesh,
                      in_specs=(spec, spec, P()),
                      out_specs=spec)
    painted = paint(padded_mesh, positions, weight)

    painted = halo_exchange(painted,
                            halo_extents=halo_extents,
                            halo_periods=(True, True))
    return slice_unpad(painted, halo_size, sharding)
|
|
|
|
|
|
|
|
|
|
|
|
def _cic_read_impl(grid_mesh, positions):
|
2024-07-09 14:54:34 -04:00
|
|
|
""" Paints positions onto mesh
|
2024-12-20 11:44:02 +01:00
|
|
|
mesh: [nx, ny, nz]
|
|
|
|
positions: [nx,ny,nz, 3]
|
|
|
|
"""
|
|
|
|
# Save original shape for reshaping output later
|
|
|
|
original_shape = positions.shape
|
|
|
|
# Reshape positions to a flat list of 3D coordinates
|
|
|
|
positions = positions.reshape([-1, 3])
|
|
|
|
# Expand dimensions to calculate neighbor coordinates
|
2024-07-09 14:54:34 -04:00
|
|
|
positions = jnp.expand_dims(positions, 1)
|
2024-12-20 11:44:02 +01:00
|
|
|
# Floor the positions to get the base grid cell for each particle
|
2024-07-09 14:54:34 -04:00
|
|
|
floor = jnp.floor(positions)
|
2024-12-20 11:44:02 +01:00
|
|
|
# Define connections to calculate all neighbor coordinates
|
2024-07-09 14:54:34 -04:00
|
|
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
|
|
|
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
2024-12-20 11:44:02 +01:00
|
|
|
# Calculate the 8 neighboring coordinates
|
2024-07-09 14:54:34 -04:00
|
|
|
neighboor_coords = floor + connection
|
2024-12-20 11:44:02 +01:00
|
|
|
# Calculate kernel weights based on distance from each neighboring coordinate
|
2024-07-09 14:54:34 -04:00
|
|
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
|
|
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
2024-12-20 11:44:02 +01:00
|
|
|
# Modulo operation to wrap around edges if necessary
|
2024-07-09 14:54:34 -04:00
|
|
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
2024-12-20 11:44:02 +01:00
|
|
|
jnp.array(grid_mesh.shape))
|
|
|
|
# Ensure grid_mesh shape is as expected
|
|
|
|
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
|
|
|
return (grid_mesh[neighboor_coords[..., 0],
|
|
|
|
neighboor_coords[..., 1],
|
|
|
|
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
    """Read ``grid_mesh`` at particle positions, optionally sharded.

    Args:
        grid_mesh: mesh to sample from.
        positions: particle positions; reshaped to ``(*grid_mesh.shape, 3)``
            before reading.
        halo_size: halo specification handed to ``get_halo_size`` (static
            argument of the jitted function).
        sharding: optional ``NamedSharding``. Any other value runs with no
            device mesh and an empty partition spec (static argument).

    Returns:
        The values read by ``_cic_read_impl``, reshaped to the leading
        dimensions of the ``positions`` array as it was passed in.
    """
    output_shape = positions.shape[:-1]
    positions = positions.reshape((*grid_mesh.shape, 3))

    # Pad and exchange halos before reading from the mesh.
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    padded_mesh = halo_exchange(padded_mesh,
                                halo_extents=halo_extents,
                                halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    read = autoshmap(_cic_read_impl,
                     gpu_mesh=gpu_mesh,
                     in_specs=(spec, spec),
                     out_specs=spec)
    return read(padded_mesh, positions).reshape(output_shape)
|
2022-02-13 21:36:03 +01:00
|
|
|
|
2022-03-26 00:06:34 +01:00
|
|
|
|
2022-05-17 23:37:55 +02:00
|
|
|
def cic_paint_2d(mesh, positions, weight=None):
    """Paint particles onto a 2D mesh with cloud-in-cell (CIC) weights.

    Args:
        mesh: [nx, ny] array that the particle contributions are added to.
        positions: [npart, 2] particle coordinates; their floor is used
            directly as a cell index.
        weight: optional particle weight, either a scalar or an [npart]
            array. ``None`` (the default) leaves the CIC weights unscaled.

    Returns:
        Array shaped like ``mesh`` holding the input mesh plus the painted
        contributions. Cell indices wrap modulo the mesh shape.
    """
    # Singleton axis so each particle broadcasts against its 4 cell corners.
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])

    neighboor_coords = floor + connection
    # Per-axis linear weight, multiplied across the two axes.
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1]
    if weight is not None:
        # expand_dims accepts scalars as well as arrays, so both a single
        # global weight and one weight per particle broadcast over corners.
        kernel = kernel * jnp.expand_dims(weight, axis=-1)

    # Integer cell indices, wrapped periodically into the mesh.
    neighboor_coords = jnp.mod(
        neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
        jnp.array(mesh.shape))

    # Every update is a single value addressed by a full (i, j) index.
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1),
        scatter_dims_to_operand_dims=(0, 1))
    return lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]),
                           dnums)
|
|
|
|
|
2022-05-17 11:19:56 +02:00
|
|
|
|
2024-12-20 11:44:02 +01:00
|
|
|
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
    """Paint particles given as displacements from a regular grid.

    Each particle starts on one cell of a regular grid whose shape is
    ``displacements.shape[:-1]``. Its starting index (shifted by the leading
    halo widths on the first two axes) and its displacement are handed to
    ``scatter`` together with a zero mesh padded by ``halo_size``.

    Args:
        displacements: array whose last axis holds the 3 displacement
            components; the leading axes define the particle grid.
        halo_size: per-axis pairs passed as ``pad_width`` to ``jnp.pad``.
            The first entry of the pairs for axes 0 and 1 offsets the
            particle indices.
        weight: scalar, or array of shape ``displacements.shape[:-1]``;
            forwarded to ``scatter`` as ``val`` (flattened if an array).
        chunk_size: forwarded to ``scatter``.

    Returns:
        Whatever ``scatter`` returns for the padded mesh.
        NOTE(review): presumably the painted padded mesh; confirm against
        ``jaxpm.painting_utils.scatter``.

    Raises:
        ValueError: if a non-scalar ``weight`` does not match the particle
            grid shape.
    """
    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = displacements.shape
    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
    if not jnp.isscalar(weight):
        if weight.shape != original_shape[:-1]:
            raise ValueError("Weight shape must match particle shape")
        weight = weight.flatten()

    # Integer grid index of every particle's starting cell.
    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                           jnp.arange(particle_mesh.shape[1]),
                           jnp.arange(particle_mesh.shape[2]),
                           indexing='ij')

    # Original note: padding is forced to be zero in a single gpu run.
    particle_mesh = jnp.pad(particle_mesh, halo_size)
    # Shift indices so they address the interior of the padded mesh.
    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)

    # Forward the caller's chunk_size; it used to be hard-coded to 2**24
    # here, which silently ignored the parameter.
    return scatter(pmid.reshape([-1, 3]),
                   displacements.reshape([-1, 3]),
                   particle_mesh,
                   chunk_size=chunk_size,
                   val=weight)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
                 halo_size=0,
                 sharding=None,
                 weight=1.0,
                 chunk_size=2**24):
    """Paint particles described by displacements, optionally sharded.

    Args:
        displacements: displacement array handed to ``_cic_paint_dx_impl``.
        halo_size: halo specification handed to ``get_halo_size`` (static
            argument of the jitted function).
        sharding: optional ``NamedSharding``. Any other value runs with no
            device mesh and an empty partition spec (static argument).
        weight: weight forwarded to ``_cic_paint_dx_impl``.
        chunk_size: forwarded to ``_cic_paint_dx_impl`` (static argument).

    Returns:
        The painted mesh after halo exchange and ``slice_unpad``.
    """
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    # Bind the non-array arguments so only displacements are sharded.
    paint_shard = partial(_cic_paint_dx_impl,
                          halo_size=halo_size,
                          weight=weight,
                          chunk_size=chunk_size)
    painted = autoshmap(paint_shard,
                        gpu_mesh=gpu_mesh,
                        in_specs=spec,
                        out_specs=spec)(displacements)

    painted = halo_exchange(painted,
                            halo_extents=halo_extents,
                            halo_periods=(True, True))
    return slice_unpad(painted, halo_size, sharding)
|
|
|
|
|
|
|
|
|
|
|
|
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
    """Read a halo-padded mesh at particles displaced from a regular grid.

    Args:
        grid_mesh: mesh including halo padding.
        disp: displacements, flattened to [-1, 3] before being passed on.
        halo_size: per-axis pairs. Only the first entry of each pair is
            used: doubled to recover the unpadded shape, and (for axes 0 and
            1) as the index offset into the padded mesh.

    Returns:
        The result of ``gather`` reshaped to the unpadded mesh shape.
        NOTE(review): presumably the interpolated mesh values; confirm
        against ``jaxpm.painting_utils.gather``.
    """
    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    # Mesh shape with the halo stripped from both sides of every axis.
    unpadded_shape = [
        extent - 2 * pad[0] for extent, pad in zip(grid_mesh.shape, halo_size)
    ]

    # Integer grid index of every particle's starting cell.
    axes = [jnp.arange(unpadded_shape[axis]) for axis in range(3)]
    ii, jj, kk = jnp.meshgrid(*axes, indexing='ij')

    # Shift indices so they address the interior of the padded mesh.
    cell_ids = jnp.stack([ii + halo_x, jj + halo_y, kk], axis=-1)

    values = gather(cell_ids.reshape([-1, 3]), disp.reshape([-1, 3]),
                    grid_mesh)
    return values.reshape(unpadded_shape)
|
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
    """Read ``grid_mesh`` at displaced grid particles, optionally sharded.

    Args:
        grid_mesh: mesh to sample from.
        disp: displacements handed to ``_cic_read_dx_impl``.
        halo_size: halo specification handed to ``get_halo_size`` (static
            argument of the jitted function).
        sharding: optional ``NamedSharding``. Any other value runs with no
            device mesh and an empty partition spec (static argument).

    Returns:
        The output of ``_cic_read_dx_impl`` for the padded, halo-exchanged
        mesh.
    """
    # Pad and exchange halos before reading from the mesh.
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    padded_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    padded_mesh = halo_exchange(padded_mesh,
                                halo_extents=halo_extents,
                                halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    # A single spec (not a tuple) is passed for both inputs.
    read_shard = partial(_cic_read_dx_impl, halo_size=halo_size)
    return autoshmap(read_shard,
                     gpu_mesh=gpu_mesh,
                     in_specs=spec,
                     out_specs=spec)(padded_mesh, disp)
|
|
|
|
|
|
|
|
|
2022-03-26 00:06:34 +01:00
|
|
|
def compensate_cic(field):
    """
    Compensate for CiC painting.

    The field is transformed with ``fft3d``, multiplied by
    ``cic_compensation`` evaluated on the wavevectors from ``fftk``, and
    transformed back with ``ifft3d``.

    Args:
      field: input 3D cic-painted field
    Returns:
      compensated_field
    """
    field_k = fft3d(field)
    wavevectors = fftk(field_k)
    compensated_k = cic_compensation(wavevectors) * field_k
    return ifft3d(compensated_k)
|