JaxPM/jaxpm/_src/painting_ops.py
2024-07-08 00:23:35 +02:00

559 lines
19 KiB
Python

from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.lax import scan
from jaxdecomp import halo_exchange
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
ShardedOperator, register_operator)
from jaxpm.ops import slice_pad, slice_unpad
def _cic_paint_scatter(particle_mesh, positions):
    """Scatter-add cloud-in-cell (CIC) weights of ``positions`` onto ``particle_mesh``.

    Each particle deposits a total weight of 1, split over the 8 cells at the
    corners of the unit cube above ``floor(position)``. Neighbour indices are
    wrapped with ``jnp.mod`` over the mesh shape, so the mesh is periodic.

    Parameters
    ----------
    particle_mesh : jnp.ndarray
        3D mesh to paint onto.
    positions : jnp.ndarray
        Positions in mesh-cell units; any shape whose last axis has size 3.

    Returns
    -------
    jnp.ndarray
        ``particle_mesh`` with the CIC weights added.
    """
    positions = jnp.expand_dims(positions.reshape([-1, 3]), 1)  # (N, 1, 3)
    floor = jnp.floor(positions)
    # Offsets of the 8 corners of the unit cube around each particle.
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1],
                             [1., 1, 1]]])
    neighboor_coords = floor + connection  # (N, 8, 3)
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # (N, 8)
    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
                               jnp.array(particle_mesh.shape))
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1, 2),
        scatter_dims_to_operand_dims=(0, 1, 2))
    return lax.scatter_add(particle_mesh, neighboor_coords, kernel, dnums)


class CICPaintOperator(ShardedOperator):
    """Cloud-in-cell painting of particle positions onto a (possibly sharded) mesh."""

    name = 'cic_paint'

    def single_gpu_impl(particle_mesh: jnp.ndarray,
                        positions: jnp.ndarray,
                        halo_size=0):
        """Paint ``positions`` onto ``particle_mesh`` on a single device.

        ``halo_size`` is accepted for interface compatibility and ignored.
        """
        del halo_size
        return _cic_paint_scatter(particle_mesh, positions)

    def multi_gpu_impl(particle_mesh: jnp.ndarray,
                       positions: jnp.ndarray,
                       halo_size=8,
                       __aux_input=None):
        """Paint local positions onto a halo-padded copy of the local mesh slab.

        ``__aux_input`` is the pair of axis sizes produced by
        ``get_aux_input_from_base_sharding``; a size of 1 means that mesh axis
        is not split, and it is therefore not padded.

        Returns
        -------
        (jnp.ndarray, int)
            The padded painted mesh and ``halo_size``, which
            ``multi_gpu_epilog`` uses to exchange and strip the halos.
        """
        rank = jax.process_index()
        # Shift positions so they are relative to the start of this slab.
        # NOTE(review): the slab index along axis 1 uses `rank // aux[0]` while
        # axis 0 uses `rank % aux[1]`; that is only a consistent decomposition
        # of `rank` when both sizes agree, and `process_index` counts processes,
        # not devices -- confirm against the sharding layout.
        correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0])
        correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1])
        # `...` accepts grid-shaped (nx, ny, nz, 3) as well as flat (N, 3) input.
        positions = positions.at[..., 1].add(correct_y)
        positions = positions.at[..., 0].add(correct_z)
        positions = positions.reshape([-1, 3])

        halo_tuple = (halo_size, halo_size)
        if __aux_input[0] == 1:
            halo_width = ((0, 0), halo_tuple, (0, 0))
            halo_start = [0, halo_size, 0]
        elif __aux_input[1] == 1:
            halo_width = (halo_tuple, (0, 0), (0, 0))
            halo_start = [halo_size, 0, 0]
        else:
            halo_width = (halo_tuple, halo_tuple, (0, 0))
            halo_start = [halo_size, halo_size, 0]
        particle_mesh = jnp.pad(particle_mesh, halo_width)
        # Move positions into the padded frame: slab cell 0 now sits after the halo.
        positions += jnp.array(halo_start).reshape([-1, 3])
        return _cic_paint_scatter(particle_mesh, positions), halo_size

    def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
        """Exchange painted halos between shards and strip the padding.

        Calls ``halo_exchange`` (periodic on all axes) with extents of
        ``halo_size // 2`` on the padded axes, then ``slice_unpad`` with
        ``halo_size`` on those same axes.
        """
        if __aux_input[0] == 1:
            halo_width = (0, halo_size, 0)
            halo_extents = (0, halo_size // 2, 0)
        elif __aux_input[1] == 1:
            halo_width = (halo_size, 0, 0)
            halo_extents = (halo_size // 2, 0, 0)
        else:
            halo_width = (halo_size, halo_size, 0)
            halo_extents = (halo_size // 2, halo_size // 2, 0)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
        return particle_mesh

    def get_aux_input_from_base_sharding(base_sharding):
        """Return the mesh sizes of the axes that shard array dimensions 0 and 1.

        A dimension that is unsharded (``None`` in the spec, or missing because
        the spec has fewer than two entries) contributes a size of 1.
        """
        spec = base_sharding.spec

        def axis_size(index):
            axis_name = spec[index] if index < len(spec) else None
            if axis_name is None:
                return 1
            return base_sharding.mesh.shape[axis_name]

        return [axis_size(i) for i in range(2)]
def _cic_read_gather(particle_mesh, positions):
    """Interpolate ``particle_mesh`` at ``positions`` with cloud-in-cell weights.

    Neighbour indices are wrapped with ``jnp.mod`` over the mesh shape (periodic).

    Parameters
    ----------
    particle_mesh : jnp.ndarray
        3D mesh to read from.
    positions : jnp.ndarray
        Positions in mesh-cell units; any shape whose last axis has size 3.

    Returns
    -------
    jnp.ndarray
        Flat array with one interpolated value per position.
    """
    positions = jnp.expand_dims(positions.reshape([-1, 3]), 1)  # (N, 1, 3)
    floor = jnp.floor(positions)
    # Offsets of the 8 corners of the unit cube around each particle.
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1],
                             [1., 1, 1]]])
    neighboor_coords = floor + connection  # (N, 8, 3)
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]  # (N, 8)
    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
                               jnp.array(particle_mesh.shape))
    # The last coordinate axis has size 3, so the z index is 2 (was 3).
    values = particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
                           neighboor_coords[..., 2]]
    return (values * kernel).sum(axis=-1)


class CICReadOperator(ShardedOperator):
    """Cloud-in-cell interpolation of a (possibly sharded) mesh at particle positions."""

    name = 'cic_read'

    def single_gpu_impl(particle_mesh: jnp.ndarray,
                        positions: jnp.ndarray,
                        halo_size=0):
        """Read ``particle_mesh`` at ``positions`` on a single device.

        Returns one value per particle, shaped like ``positions`` without its
        trailing coordinate axis. ``halo_size`` is ignored.
        """
        del halo_size
        original_shape = positions.shape
        return _cic_read_gather(particle_mesh,
                                positions).reshape(original_shape[:-1])

    def multi_gpu_prolog(particle_mesh: jnp.ndarray,
                         positions: jnp.ndarray,
                         halo_size=0,
                         __aux_input=None):
        """Pad the local mesh slab with halos and fill them from neighbouring shards.

        Axes whose ``__aux_input`` size is 1 are not padded. Returns the padded
        mesh, the untouched ``positions`` and ``halo_size``.
        """
        halo_tuple = (halo_size, halo_size)
        if __aux_input[0] == 1:
            halo_width = ((0, 0), halo_tuple, (0, 0))
            halo_extents = (0, halo_size // 2, 0)
        elif __aux_input[1] == 1:
            halo_width = (halo_tuple, (0, 0), (0, 0))
            halo_extents = (halo_size // 2, 0, 0)
        else:
            halo_width = (halo_tuple, halo_tuple, (0, 0))
            halo_extents = (halo_size // 2, halo_size // 2, 0)
        particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return particle_mesh, positions, halo_size

    def multi_gpu_impl(particle_mesh: jnp.ndarray,
                       positions: jnp.ndarray,
                       halo_size=0,
                       __aux_input=None):
        """Read the halo-padded local mesh at ``positions``.

        Positions are shifted by ``halo_size`` along the padded axes. Returns
        one value per particle, shaped like ``positions`` minus its last axis.
        """
        original_shape = positions.shape
        positions = positions.reshape([-1, 3])
        if __aux_input[0] == 1:
            halo_start = [0, halo_size, 0]
        elif __aux_input[1] == 1:
            halo_start = [halo_size, 0, 0]
        else:
            halo_start = [halo_size, halo_size, 0]
        # NOTE(review): unlike CICPaintOperator.multi_gpu_impl, positions are
        # not shifted by this slab's global offset here -- confirm callers pass
        # slab-local positions.
        positions += jnp.array(halo_start).reshape([-1, 3])
        return _cic_read_gather(particle_mesh,
                                positions).reshape(original_shape[:-1])
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
return remainder, chunks
def enmesh(i1, d1, a1, s1, b12, a2, s2):
    """Multilinear enmeshing.

    Map points given as (grid index, displacement) on an input grid to the
    2**dim surrounding cells of an output mesh, with multilinear weights.

    Parameters
    ----------
    i1 : ArrayLike
        Integer input-grid indices, presumably shape (N, dim) (``dim`` is read
        from ``i1.shape[1]`` and a neighbour axis is inserted at axis 1).
    d1 : ArrayLike
        Displacements, same shape as ``i1``; a point is at ``i1 * a1 + d1``.
    a1 : float
        Input grid spacing.
    s1 : ArrayLike or None
        Input grid shape. When given, positions wrap periodically (period
        ``s1 * a1`` or ``s1`` cells) so all output indices are non-negative.
    b12 : float
        Offset of the output mesh relative to the input grid (subtracted).
    a2 : float or None
        Output mesh cell size. When None, the output cell size is ``a1``.
    s2 : ArrayLike or None
        Output mesh shape. Only used when ``s1`` is None: negative indices are
        replaced by ``s2`` (presumably so that callers using ``mode='drop'``
        discard them -- confirm at call sites).

    Returns
    -------
    i2 : jax.Array
        Output-mesh indices, shape (N, 2**dim, dim), dtype of ``i1``.
    f2 : jax.Array
        Weights ``prod(1 - |d2|)`` over the last axis, shape (N, 2**dim).
    """
    i1 = jnp.asarray(i1)
    d1 = jnp.asarray(d1)
    # NOTE(review): jnp.float64 only gives float64 when JAX x64 mode is enabled.
    a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
    if s1 is not None:
        s1 = jnp.array(s1, dtype=i1.dtype)
    b12 = jnp.float64(b12)
    if a2 is not None:
        a2 = jnp.float64(a2)
    if s2 is not None:
        s2 = jnp.array(s2, dtype=i1.dtype)
    dim = i1.shape[1]
    # (2**dim, dim) table of 0/1 corner offsets: bit k of the row index.
    neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
                 jnp.arange(dim, dtype=i1.dtype)) & 1
    if a2 is not None:
        # Work with absolute positions measured from the output mesh origin.
        P = i1 * a1 + d1 - b12
        P = P[:, jnp.newaxis]  # insert neighbor axis
        i2 = P + neighbors * a2  # multilinear
        if s1 is not None:
            # Wrap into the periodic box of length L before binning.
            L = s1 * a1
            i2 %= L
        i2 //= a2
        d2 = P - i2 * a2
        if s1 is not None:
            d2 -= jnp.rint(d2 / L) * L  # also abs(d2) < a2 is expected
        i2 = i2.astype(i1.dtype)
        d2 = d2.astype(d1.dtype)
        a2 = a2.astype(d1.dtype)
        # Express the separation in units of output cells.
        d2 /= a2
    else:
        # Output cells are a1 in size: fold the offset into index + remainder.
        i12, d12 = jnp.divmod(b12, a1)
        i1 -= i12.astype(i1.dtype)
        d1 -= d12.astype(d1.dtype)
        # insert neighbor axis
        i1 = i1[:, jnp.newaxis]
        d1 = d1[:, jnp.newaxis]
        # multilinear
        d1 /= a1
        i2 = jnp.floor(d1).astype(i1.dtype)
        i2 += neighbors
        d2 = d1 - i2
        i2 += i1
        if s1 is not None:
            i2 %= s1
    # Per-axis linear weight of each corner.
    f2 = 1 - jnp.abs(d2)
    if s1 is None and s2 is not None:  # all i2 >= 0 if s1 is not None
        i2 = jnp.where(i2 < 0, s2, i2)
    # Multilinear weight = product of the per-axis weights.
    f2 = f2.prod(axis=-1)
    return i2, f2
def _scatter_chunk(carry, chunk):
    """Scatter-add one chunk of particle values onto the mesh (one `lax.scan` step).

    Parameters
    ----------
    carry : tuple
        ``(mesh_shape, mesh, offset, cell_size)``. Only ``mesh`` changes; the
        whole tuple is returned with the same structure, as ``lax.scan``
        requires.
    chunk : sequence
        ``(pmid, disp, val)`` for this chunk. ``val`` is either 0D or has one
        entry per particle.

    Returns
    -------
    (tuple, None)
        Updated carry and no per-step output.
    """
    mesh_shape, mesh, offset, cell_size = carry
    pmid, disp, val = chunk
    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape
    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)
    # scatter
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    if val.ndim != 0:
        # Per-particle values must broadcast over the neighbour axis of `frac`.
        val = val[:, jnp.newaxis]
    mesh = mesh.at[ind].add(val * frac)
    return (mesh_shape, mesh, offset, cell_size), None


def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    """Scatter particle values onto ``mesh`` multilinearly, in chunks.

    Parameters
    ----------
    pmid : ArrayLike
        Integer particle grid indices, shape (N, dim).
    disp : ArrayLike
        Particle displacements, shape (N, dim); a particle sits at
        ``pmid * cell_size + disp``.
    mesh : ArrayLike
        Mesh to add onto; its full shape is used as the periodic box.
    chunk_size : int or None, optional
        Particles processed per ``lax.scan`` step (None: all at once).
    val : ArrayLike, optional
        Value carried by each particle; 0D or shape (N,).
    offset : ArrayLike, optional
        Offset of the mesh relative to the particle grid.
    cell_size : float, optional
        Cell size of both the particle grid and the mesh.

    Returns
    -------
    jax.Array
        ``mesh`` with the particle values added.
    """
    ptcl_num, _ = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)
    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = mesh.shape, mesh, offset, cell_size
    if remainder is not None:
        carry = _scatter_chunk(carry, remainder)[0]
    carry = scan(_scatter_chunk, carry, chunks)[0]
    # carry = (mesh_shape, mesh, offset, cell_size): the mesh is item 1.
    return carry[1]
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
    """Gather particle values from mesh multilinearly in n-D.

    Parameters
    ----------
    ptcl : Particles
        Object exposing ``pmid`` (grid indices) and ``disp`` (displacements).
    conf : Configuration
        Used only for its ``cell_size`` when ``cell_size`` is None.
    mesh : ArrayLike
        Input mesh.
    val : ArrayLike, optional
        Input values, can be 0D. Gathered mesh values are added to it.
    offset : ArrayLike, optional
        Offset of mesh to particle grid. If 0D, the value is used in each dimension.
    cell_size : float, optional
        Mesh cell size in [L]. Default is ``conf.cell_size``.

    Returns
    -------
    val : jax.Array
        Output values.
    """
    if cell_size is None:
        cell_size = conf.cell_size
    # Pass by keyword: `_gather` takes (pmid, disp, mesh, chunk_size, ...), so
    # positional `conf` used to land in `mesh` and `mesh` in `chunk_size`.
    return _gather(ptcl.pmid, ptcl.disp, mesh,
                   val=val, offset=offset, cell_size=cell_size)
def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
if remainder_array is not None:
array = jnp.concatenate((remainder_array, array), axis=0)
return array
def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=None):
    """Gather mesh values at particle positions multilinearly, in chunks.

    Parameters
    ----------
    pmid : ArrayLike
        Integer particle grid indices, shape (N, dim).
    disp : ArrayLike
        Particle displacements, shape (N, dim).
    mesh : ArrayLike
        Mesh with ``dim`` spatial axes followed by optional channel axes.
    chunk_size : int or None, optional
        Particles processed per ``lax.scan`` step (None: all at once).
    val : ArrayLike, optional
        Values the gathered mesh values are added to; 0D or shape
        (N, *channels). NOTE(review): with the default of 1 every output is
        offset by 1 -- pass ``val=0`` to read plain mesh values.
    offset : ArrayLike, optional
        Offset of the mesh relative to the particle grid.
    cell_size : float, optional
        Cell size of both the particle grid and the mesh (a particle sits at
        ``pmid * cell_size + disp``). None means 1, the default used by
        `scatter`.

    Returns
    -------
    jax.Array
        One gathered value per particle.

    Raises
    ------
    ValueError
        If the mesh channel shape does not match ``val.shape[1:]``.
    """
    ptcl_num, spatial_ndim = pmid.shape
    mesh = jnp.asarray(mesh)
    val = jnp.asarray(val)
    if cell_size is None:
        # `enmesh` cannot take a None cell size (jnp.float64(None) fails).
        cell_size = 1.
    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = mesh.shape, mesh, offset, cell_size
    val_0 = None
    if remainder is not None:
        val_0 = _gather_chunk(carry, remainder)[1]
    val = scan(_gather_chunk, carry, chunks)[1]
    return _chunk_cat(val_0, val)
def _gather_chunk(carry, chunk):
    """Gather mesh values for one chunk of particles (one `lax.scan` step).

    Parameters
    ----------
    carry : tuple
        ``(mesh_shape, mesh, offset, cell_size)``, returned unchanged.
    chunk : sequence
        ``(pmid, disp, val)`` for this chunk; gathered values are added to ``val``.

    Returns
    -------
    (tuple, jax.Array)
        The unchanged carry and the gathered values for this chunk.
    """
    mesh_shape, mesh, offset, cell_size = carry
    pmid, disp, val = chunk
    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim
    chan_axis = tuple(range(-chan_ndim, 0))
    # multilinear mesh indices and fractions; `enmesh` takes 7 arguments and
    # only the spatial part of the mesh shape is the periodic box.
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape[:spatial_ndim],
                       offset, cell_size, spatial_shape)
    # gather
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = jnp.expand_dims(frac, chan_axis)
    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
    return carry, val
class CICPaintDXOperator(ShardedOperator):
    """Cloud-in-cell painting of grid particles given by their displacement field."""

    name = 'cic_paint_dx'

    def single_gpu_impl(displacement, halo_size=0):
        """Paint one particle per grid point, moved by ``displacement``.

        ``displacement`` has shape ``(*mesh_shape, 3)``; the output mesh has
        shape ``mesh_shape``. ``halo_size`` is ignored.
        """
        del halo_size
        mesh_shape = displacement.shape[:-1]
        particle_mesh = jnp.zeros(mesh_shape, dtype='float32')
        a, b, c = jnp.meshgrid(jnp.arange(mesh_shape[0]),
                               jnp.arange(mesh_shape[1]),
                               jnp.arange(mesh_shape[2]),
                               indexing='ij')
        # NOTE(review): the first two grid coordinates are swapped ([b, a, c]);
        # presumably this matches the axis layout of `displacement` -- confirm.
        pmid = jnp.stack([b, a, c], axis=-1).reshape([-1, 3])
        return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)

    def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
        """Paint the local slab's particles onto a halo-padded local mesh.

        Returns the padded mesh and ``halo_size`` for ``multi_gpu_epilog``.
        """
        local_shape = displacement.shape[:-1]
        halo_tuple = (halo_size, halo_size)
        if __aux_input[0] == 1:
            halo_width = ((0, 0), halo_tuple, (0, 0))
            halo_start = [0, halo_size, 0]
        elif __aux_input[1] == 1:
            halo_width = (halo_tuple, (0, 0), (0, 0))
            halo_start = [halo_size, 0, 0]
        else:
            halo_width = (halo_tuple, halo_tuple, (0, 0))
            halo_start = [halo_size, halo_size, 0]
        particle_mesh = jnp.pad(jnp.zeros(local_shape, dtype='float32'),
                                halo_width)
        # Build the grid over the *unpadded* slab: one grid point per
        # displacement vector (a padded grid has more rows than `displacement`).
        a, b, c = jnp.meshgrid(jnp.arange(local_shape[0]),
                               jnp.arange(local_shape[1]),
                               jnp.arange(local_shape[2]),
                               indexing='ij')
        # Shift into the padded frame, only along mesh axes that carry a halo.
        pmid = jnp.stack([b, a, c], axis=-1) + jnp.array(halo_start)
        pmid = pmid.reshape([-1, 3])
        return scatter(pmid, displacement.reshape([-1, 3]),
                       particle_mesh), halo_size

    def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
        """Exchange painted halos between shards and strip the padding.

        Calls ``halo_exchange`` (periodic on all axes) with extents of
        ``halo_size // 2`` on the padded axes, then ``slice_unpad``.
        """
        if __aux_input[0] == 1:
            halo_width = (0, halo_size, 0)
            halo_extents = (0, halo_size // 2, 0)
        elif __aux_input[1] == 1:
            halo_width = (halo_size, 0, 0)
            halo_extents = (halo_size // 2, 0, 0)
        else:
            halo_width = (halo_size, halo_size, 0)
            halo_extents = (halo_size // 2, halo_size // 2, 0)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
        return particle_mesh
class CICReadDXOperator(ShardedOperator):
    """Read a mesh at every grid point via multilinear gathering."""

    name = 'cic_read_dx'

    def single_gpu_impl(particle_mesh, halo_size=0):
        """Gather ``particle_mesh`` at each of its own grid points.

        Returns a flat array with one value per grid point. ``halo_size`` is
        ignored.
        """
        del halo_size
        a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                               jnp.arange(particle_mesh.shape[1]),
                               jnp.arange(particle_mesh.shape[2]),
                               indexing='ij')
        pmid = jnp.stack([b, a, c], axis=-1).reshape([-1, 3])
        # TODO: the result is flat; reshape once the output layout is settled.
        # val=0 so plain mesh values are returned (`_gather` adds into `val`);
        # an explicit cell_size avoids passing None down to `enmesh`.
        return _gather(pmid, jnp.zeros(pmid.shape, dtype='float32'),
                       particle_mesh, val=0., cell_size=1.)

    def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
        """Pad the local mesh slab with halos and fill them from neighbouring shards."""
        halo_tuple = (halo_size, halo_size)
        if __aux_input[0] == 1:
            halo_width = ((0, 0), halo_tuple, (0, 0))
            halo_extents = (0, halo_size // 2, 0)
        elif __aux_input[1] == 1:
            halo_width = (halo_tuple, (0, 0), (0, 0))
            halo_extents = (halo_size // 2, 0, 0)
        else:
            halo_width = (halo_tuple, halo_tuple, (0, 0))
            halo_extents = (halo_size // 2, halo_size // 2, 0)
        particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return particle_mesh, halo_size

    def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
        """Gather the halo-padded local mesh on a grid; returns (values, halo_size)."""
        halo_tuple = (halo_size, halo_size)
        if __aux_input[0] == 1:
            halo_width = ((0, 0), halo_tuple, (0, 0))
        elif __aux_input[1] == 1:
            halo_width = (halo_tuple, (0, 0), (0, 0))
        else:
            halo_width = (halo_tuple, halo_tuple, (0, 0))
        # NOTE(review): multi_gpu_prolog already pads and exchanges halos; this
        # pads again and builds the grid over the padded shape, unlike
        # CICReadOperator.multi_gpu_impl -- confirm the intended layout.
        particle_mesh = jnp.pad(particle_mesh, halo_width)
        a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                               jnp.arange(particle_mesh.shape[1]),
                               jnp.arange(particle_mesh.shape[2]),
                               indexing='ij')
        pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
        pmid = pmid.reshape([-1, 3])
        # TODO must be reshaped
        return _gather(pmid, jnp.zeros(pmid.shape, dtype='float32'),
                       particle_mesh, val=0., cell_size=1.), halo_size
# Register the painting/reading operators with the sharded-operator registry,
# in this order. CICReadDXOperator is defined above but is not registered here.
for _operator in (CICPaintOperator, CICReadOperator, CICPaintDXOperator):
    register_operator(_operator)
del _operator