Implement Ops

Wassim KABALAN 2024-07-08 00:23:35 +02:00
parent e708f5b176
commit bc2612a198
2 changed files with 785 additions and 0 deletions

jaxpm/_src/base_ops.py (new file, 226 lines)
@@ -0,0 +1,226 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import partial
from inspect import signature
import jax
import jax.numpy as jnp
import jaxdecomp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp import halo_exchange
from jaxdecomp.fft import pfft3d, pifft3d
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
ShardedOperator, register_operator)
class FFTOperator(CustomPartionedOperator):
name = 'fftn'
def single_gpu_impl(x):
return jnp.fft.fftn(x)
def multi_gpu_impl(x):
return pfft3d(x)
class IFFTOperator(CustomPartionedOperator):
name = 'ifftn'
def single_gpu_impl(x):
return jnp.fft.ifftn(x)
def multi_gpu_impl(x):
return pifft3d(x)
class HaloExchangeOperator(CustomPartionedOperator):
name = 'halo_exchange'
# Halo exchange does nothing on a single GPU
# Inside a jit, this will be optimized out
def single_gpu_impl(x):
return x
def multi_gpu_impl(x):
return halo_exchange(x)
# Padding and unpadding should do nothing on a single GPU,
# since there is no halo exchange in that case
# Inside a jit, this will be optimized out
class PaddingOperator(ShardedOperator):
name = 'slice_pad'
def single_gpu_impl(x, pad_width):
return x
def multi_gpu_impl(x, pad_width):
return jnp.pad(x, pad_width)
def infer_sharding_from_base_sharding(base_sharding):
in_spec = base_sharding, P()
out_spec = base_sharding
return in_spec, out_spec
class UnpaddingOperator(ShardedOperator):
name = 'slice_unpad'
def single_gpu_impl(x, pad_width):
return x
def multi_gpu_impl(x, pad_width):
# WARNING: unequal halo sizes are not supported
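# The padded halo regions hold mass contributed by neighbouring slices;
# fold them back into the interior before cropping the padding away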
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y, :]
def infer_sharding_from_base_sharding(base_sharding):
in_spec = base_sharding, P()
out_spec = base_sharding
return in_spec, out_spec
class NormalFieldOperator(CallBackOperator):
name = 'normal'
def single_gpu_impl(shape, key, dtype='float32'):
return jax.random.normal(key, shape, dtype=dtype)
def multi_gpu_impl(shape, key, dtype='float32', base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
sharding = NormalFieldOperator.shardings_to_use_in_impl(base_sharding)
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name is None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
pdims = [get_axis_size(sharding, i) for i in range(2)]
local_mesh_shape = [
shape[0] // pdims[1], shape[1] // pdims[0], shape[2]
]
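# Each process draws only its local slab; the slabs are then assembled
# into a single globally-sharded array below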
return jax.make_array_from_single_device_arrays(
shape=shape,
sharding=sharding,
arrays=[jax.random.normal(key, local_mesh_shape, dtype=dtype)])
def shardings_to_use_in_impl(base_sharding):
return base_sharding
class FFTKOperator(CallBackOperator):
name = 'fftk'
def single_gpu_impl(shape, symmetric=True, finite=False, dtype=np.float32):
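# Build one broadcastable wave-vector per axis, in radians per cell;
# with symmetric=True the last axis keeps only the rfft (non-negative) frequencies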
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
k.append(kd.astype(dtype))
del kd, kdshape
return k
def multi_gpu_impl(shape,
symmetric=True,
finite=False,
dtype=np.float32,
base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
kvec = FFTKOperator.single_gpu_impl(shape, symmetric, finite, dtype)
z_sharding, y_sharding = FFTKOperator.shardings_to_use_in_impl(base_sharding)
return [
jax.make_array_from_callback(
(shape[0], 1, 1),
sharding=z_sharding,
data_callback=lambda x: kvec[0].reshape([-1, 1, 1])[x]),
jax.make_array_from_callback(
(1, shape[1], 1),
sharding=y_sharding,
data_callback=lambda x: kvec[1].reshape([1, -1, 1])[x]),
kvec[2].reshape([1, 1, -1])
]
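# The three vectors broadcast against each other: kvec[0] lies along the first
# axis, kvec[1] along the second and kvec[2] along the third, so e.g.
# kk = sum(ki**2 for ki in kvec) expands to the full 3D |k|^2 grid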
@staticmethod
def shardings_to_use_in_impl(base_sharding):
spec = base_sharding.spec
z_sharding = NamedSharding(base_sharding.mesh, P(spec[0], None, None))
y_sharding = NamedSharding(base_sharding.mesh, P(None, spec[1], None))
return z_sharding, y_sharding
class GenerateParticlesOperator(CallBackOperator):
name = 'generate_initial_positions'
def single_gpu_impl(shape):
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in shape]),
axis=-1)
def multi_gpu_impl(shape, base_sharding=None):
assert (isinstance(base_sharding, NamedSharding))
sharding = GenerateParticlesOperator.shardings_to_use_in_impl(
base_sharding)
return jax.make_array_from_callback(
shape=tuple([*shape, 3]),
sharding=sharding,
data_callback=lambda x: jnp.stack(jnp.meshgrid(
jnp.arange(shape[0])[x[0]],
jnp.arange(shape[1])[x[1]],
jnp.arange(shape[2]),
indexing='ij'),
axis=-1))
def shardings_to_use_in_impl(base_sharding):
return base_sharding
register_operator(FFTOperator)
register_operator(IFFTOperator)
register_operator(HaloExchangeOperator)
register_operator(PaddingOperator)
register_operator(UnpaddingOperator)
register_operator(NormalFieldOperator)
register_operator(FFTKOperator)
register_operator(GenerateParticlesOperator)
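The three operator base classes imported from jaxpm._src.spmd_config each pair a single-GPU implementation with a multi-GPU one; spmd_config itself is not part of this diff, and it decides which implementation a call resolves to (the resulting callables are what the next file imports from jaxpm.ops, e.g. slice_pad and slice_unpad). As a rough illustration of that dispatch pattern only, here is a minimal sketch with a hypothetical make_op helper; the helper name and the switch condition are assumptions, not the library API:

def make_op(op_cls, base_sharding=None):
    # Hypothetical dispatcher: use the distributed implementation only when a
    # base sharding (i.e. a device mesh) has been configured.
    def op(*args, **kwargs):
        if base_sharding is None:
            return op_cls.single_gpu_impl(*args, **kwargs)
        return op_cls.multi_gpu_impl(*args, **kwargs)
    return op

# e.g. fftn = make_op(FFTOperator); field_k = fftn(field)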

jaxpm/_src/painting_ops.py (new file, 559 lines)
@@ -0,0 +1,559 @@
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.lax import scan
from jaxdecomp import halo_exchange
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
ShardedOperator, register_operator)
from jaxpm.ops import slice_pad, slice_unpad
class CICPaintOperator(ShardedOperator):
name = 'cic_paint'
def single_gpu_impl(particle_mesh: jnp.ndarray,
positions: jnp.ndarray,
halo_size=0):
del halo_size
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1],
[1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
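# `connection` enumerates the 8 corners of the cell enclosing each particle;
# the CIC weight of a corner is the product of (1 - |distance|) over the three axes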
neighboor_coords_mod = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(particle_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
kernel.reshape([-1, 8]), dnums)
return particle_mesh
def multi_gpu_impl(particle_mesh: jnp.ndarray,
positions: jnp.ndarray,
halo_size=8,
__aux_input=None):
rank = jax.process_index()
correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0])
correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1])
# Get positions relative to the start of each slice
positions = positions.at[:, :, :, 1].add(correct_y)
positions = positions.at[:, :, :, 0].add(correct_z)
positions = positions.reshape([-1, 3])
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
halo_start = [0, halo_size, 0]
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
halo_start = [halo_size, 0, 0]
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
halo_start = [halo_size, halo_size, 0]
particle_mesh = jnp.pad(particle_mesh, halo_width)
positions += jnp.array(halo_start).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1],
[1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords_mod = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(particle_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
kernel.reshape([-1, 8]), dnums)
return particle_mesh, halo_size
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
if __aux_input[0] == 1:
halo_width = (0, halo_size, 0)
halo_extents = (0, halo_size // 2, 0)
elif __aux_input[1] == 1:
halo_width = (halo_size, 0, 0)
halo_extents = (halo_size // 2, 0, 0)
else:
halo_width = (halo_size, halo_size, 0)
halo_extents = (halo_size // 2, halo_size // 2, 0)
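# Exchange halos so the mass painted into the padded border is accumulated
# on the neighbouring slab, then fold and crop the padding with slice_unpad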
particle_mesh = halo_exchange(particle_mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
return particle_mesh
def get_aux_input_from_base_sharding(base_sharding):
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name is None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
return [get_axis_size(base_sharding, i) for i in range(2)]
class CICReadOperator(ShardedOperator):
name = 'cic_read'
def single_gpu_impl(particle_mesh: jnp.ndarray,
positions: jnp.ndarray,
halo_size=0):
del halo_size
original_shape = positions.shape
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1],
[1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(particle_mesh.shape))
particles = (
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 2]] * kernel).sum(axis=-1)
return particles.reshape(original_shape[:-1])
def multi_gpu_prolog(particle_mesh: jnp.ndarray,
positions: jnp.ndarray,
halo_size=0,
__aux_input=None):
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
halo_extents = (0, halo_size // 2, 0)
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
halo_extents = (halo_size // 2, 0, 0)
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
halo_extents = (halo_size // 2, halo_size // 2, 0)
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
particle_mesh = halo_exchange(particle_mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
return particle_mesh, positions, halo_size
def multi_gpu_impl(particle_mesh: jnp.ndarray,
positions: jnp.ndarray,
halo_size=0,
__aux_input=None):
original_shape = positions.shape
positions = positions.reshape([-1, 3])
if __aux_input[0] == 1:
halo_start = [0, halo_size, 0]
elif __aux_input[1] == 1:
halo_start = [halo_size, 0, 0]
else:
halo_start = [halo_size, halo_size, 0]
positions += jnp.array(halo_start).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1],
[1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(particle_mesh.shape))
particles = (
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 2]] * kernel).sum(axis=-1)
return particles.reshape(original_shape[:-1])
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
return remainder, chunks
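# e.g. ptcl_num=10, chunk_size=4: the first 2 particles form the remainder and
# the remaining 8 are reshaped to (2, 4, ...) so they can be consumed by `scan`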
def enmesh(i1, d1, a1, s1, b12, a2, s2):
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
if s1 is not None:
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
jnp.arange(dim, dtype=i1.dtype)) & 1
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
if s1 is not None:
L = s1 * a1
i2 %= L
i2 //= a2
d2 = P - i2 * a2
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
d2 /= a2
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
if s1 is not None:
i2 %= s1
f2 = 1 - jnp.abs(d2)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
f2 = f2.prod(axis=-1)
return i2, f2
def _scatter_chunk(carry, chunk):
mesh_shape, mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# scatter
ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac)
carry = mesh_shape, mesh, offset, cell_size
return carry, None
def scatter(pmid,
disp,
mesh,
chunk_size=2**24,
val=1.,
offset=0,
cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jnp.asarray(val)
mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh.shape, mesh, offset, cell_size
if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0]
mesh = carry[1]
return mesh
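# e.g. density = scatter(pmid, disp, jnp.zeros(mesh_shape)) deposits unit mass per
# particle (val=1.) onto the mesh with multilinear (CIC) weights, in chunks of 2**24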
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
"""Gather particle values from mesh multilinearly in n-D.
Parameters
----------
ptcl : Particles
conf : Configuration
mesh : ArrayLike
Input mesh.
val : ArrayLike, optional
Input values, can be 0D.
offset : ArrayLike, optional
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
cell_size : float, optional
Mesh cell size in [L]. Default is ``conf.cell_size``.
Returns
-------
val : jax.Array
Output values.
"""
return _gather(ptcl.pmid, ptcl.disp, mesh, val=val, offset=offset,
cell_size=cell_size if cell_size is not None else conf.cell_size)
def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
if remainder_array is not None:
array = jnp.concatenate((remainder_array, array), axis=0)
return array
def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh)
val = jnp.asarray(val)
if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: '
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp,
val)
carry = mesh.shape, mesh, offset, cell_size
val_0 = None
if remainder is not None:
val_0 = _gather_chunk(carry, remainder)[1]
val = scan(_gather_chunk, carry, chunks)[1]
val = _chunk_cat(val_0, val)
return val
def _gather_chunk(carry, chunk):
mesh_shape, mesh, offset, cell_size = carry
pmid, disp, val = chunk
spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape[:spatial_ndim]
chan_ndim = mesh.ndim - spatial_ndim
chan_axis = tuple(range(-chan_ndim, 0))
# multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset,
cell_size, spatial_shape)
# gather
ind = tuple(ind[..., i] for i in range(spatial_ndim))
frac = jnp.expand_dims(frac, chan_axis)
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
return carry, val
class CICPaintDXOperator(ShardedOperator):
name = 'cic_paint_dx'
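# Paint particles that start on the regular grid and are moved by `displacement`
# (a Lagrangian displacement field) onto the mesh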
def single_gpu_impl(displacement, halo_size=0):
del halo_size
original_shape = displacement.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
pmid = jnp.stack([b, a, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
original_shape = displacement.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
particle_mesh = jnp.pad(particle_mesh, halo_width)
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return scatter(pmid, displacement.reshape([-1, 3]),
particle_mesh), halo_size
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
if __aux_input[0] == 1:
halo_width = (0, halo_size, 0)
halo_extents = (0, halo_size // 2, 0)
elif __aux_input[1] == 1:
halo_width = (halo_size, 0, 0)
halo_extents = (halo_size // 2, 0, 0)
else:
halo_width = (halo_size, halo_size, 0)
halo_extents = (halo_size // 2, halo_size // 2, 0)
particle_mesh = halo_exchange(particle_mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
return particle_mesh
class CICReadDXOperator(ShardedOperator):
name = 'cic_read_dx'
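# Read the mesh back at the (undisplaced) grid positions; the displacement
# passed to _gather is zero here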
def single_gpu_impl(particle_mesh, halo_size=0):
del halo_size
original_shape = (*particle_mesh.shape, 3)
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
pmid = jnp.stack([b, a, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
halo_extents = (0, halo_size // 2, 0)
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
halo_extents = (halo_size // 2, 0, 0)
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
halo_extents = (halo_size // 2, halo_size // 2, 0)
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
particle_mesh = halo_exchange(particle_mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
return particle_mesh, halo_size
def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
original_shape = (*particle_mesh.shape, 3)
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
particle_mesh = jnp.pad(particle_mesh, halo_width)
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
pmid = pmid.reshape([-1, 3])
# TODO must be reshaped
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size
register_operator(CICPaintOperator)
register_operator(CICReadOperator)
register_operator(CICPaintDXOperator)
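
To make the painting entry points concrete, here is a minimal single-GPU round trip using the implementations above; it is only a sketch of the single-GPU code path (the multi-GPU path goes through the spmd_config registry and halo exchange), and the grid size and displacement are arbitrary example values:

import jax.numpy as jnp

n = 32
mesh = jnp.zeros((n, n, n), dtype='float32')
grid = jnp.stack(jnp.meshgrid(*([jnp.arange(n)] * 3), indexing='ij'),
                 axis=-1).astype('float32')
positions = grid + 0.25  # shift every particle by a quarter cell

painted = CICPaintOperator.single_gpu_impl(mesh, positions)
# total painted mass equals the number of particles (periodic wrap conserves mass)
assert jnp.allclose(painted.sum(), n ** 3)

values = CICReadOperator.single_gpu_impl(painted, positions)  # shape (n, n, n)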