From bc2612a1987e25be02ede13f265ac72e4d39ffb3 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN
Date: Mon, 8 Jul 2024 00:23:35 +0200
Subject: [PATCH] Implement Ops

---
 jaxpm/_src/base_ops.py     | 226 +++++++++++++++
 jaxpm/_src/painting_ops.py | 559 +++++++++++++++++++++++++++++++++++++
 2 files changed, 785 insertions(+)
 create mode 100644 jaxpm/_src/base_ops.py
 create mode 100644 jaxpm/_src/painting_ops.py

diff --git a/jaxpm/_src/base_ops.py b/jaxpm/_src/base_ops.py
new file mode 100644
index 0000000..1c97557
--- /dev/null
+++ b/jaxpm/_src/base_ops.py
@@ -0,0 +1,226 @@
+from abc import ABCMeta, abstractmethod
+from dataclasses import dataclass
+from functools import partial
+from inspect import signature
+
+import jax
+import jax.numpy as jnp
+import jaxdecomp
+import numpy as np
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from jaxdecomp import halo_exchange
+from jaxdecomp.fft import pfft3d, pifft3d
+
+from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
+                                    ShardedOperator, register_operator)
+
+
+class FFTOperator(CustomPartionedOperator):
+
+    name = 'fftn'
+
+    def single_gpu_impl(x):
+        return jnp.fft.fftn(x)
+
+    def multi_gpu_impl(x):
+        return pfft3d(x)
+
+
+class IFFTOperator(CustomPartionedOperator):
+
+    name = 'ifftn'
+
+    def single_gpu_impl(x):
+        return jnp.fft.ifftn(x)
+
+    def multi_gpu_impl(x):
+        return pifft3d(x)
+
+
+class HaloExchangeOperator(CustomPartionedOperator):
+
+    name = 'halo_exchange'
+
+    # Halo exchange does nothing on a single GPU.
+    # Inside a jit, this will be optimized out.
+    def single_gpu_impl(x):
+        return x
+
+    def multi_gpu_impl(x):
+        return halo_exchange(x)
+
+
+# The padding and unpadding operators do nothing on a single GPU,
+# since there is no halo exchange in that case.
+# Inside a jit, they will be optimized out.
+
+
+class PaddingOperator(ShardedOperator):
+
+    name = 'slice_pad'
+
+    def single_gpu_impl(x, pad_width):
+        return x
+
+    def multi_gpu_impl(x, pad_width):
+        return jnp.pad(x, pad_width)
+
+    def infer_sharding_from_base_sharding(base_sharding):
+
+        in_spec = base_sharding, P()
+        out_spec = base_sharding
+
+        return in_spec, out_spec
+
+
+class UnpaddingOperator(ShardedOperator):
+
+    name = 'slice_unpad'
+
+    def single_gpu_impl(x, pad_width):
+        return x
+
+    def multi_gpu_impl(x, pad_width):
+
+        # WARNING : unequal halo size is not supported
+        halo_x, _ = pad_width[0]
+        halo_y, _ = pad_width[1]
+
+        # Apply corrections along x
+        x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
+        x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
+        # Apply corrections along y
+        x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
+        x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
+
+        return x[halo_x:-halo_x, halo_y:-halo_y, :]
+
+    def infer_sharding_from_base_sharding(base_sharding):
+
+        in_spec = base_sharding, P()
+        out_spec = base_sharding
+
+        return in_spec, out_spec
+
+
+class NormalFieldOperator(CallBackOperator):
+
+    name = 'normal'
+
+    def single_gpu_impl(shape, key, dtype='float32'):
+        return jax.random.normal(key, shape, dtype=dtype)
+
+    def multi_gpu_impl(shape, key, dtype='float32', base_sharding=None):
+
+        assert (isinstance(base_sharding, NamedSharding))
+        sharding = NormalFieldOperator.shardings_to_use_in_impl(base_sharding)
+
+        def get_axis_size(sharding, index):
+            axis_name = sharding.spec[index]
+            if axis_name == None:
+                return 1
+            else:
+                return sharding.mesh.shape[sharding.spec[index]]
+
+        pdims = [get_axis_size(sharding, i) for i in range(2)]
+        local_mesh_shape = [
+            shape[0] // pdims[1], shape[1] // pdims[0], shape[2]
+        ]
+
+        return jax.make_array_from_single_device_arrays(
+            shape=shape,
+            sharding=sharding,
+            arrays=[jax.random.normal(key, local_mesh_shape, dtype=dtype)])
+
+    def shardings_to_use_in_impl(base_sharding):
+        return base_sharding
+
+
+class FFTKOperator(CallBackOperator):
+    name = 'fftk'
+
+    def single_gpu_impl(shape, symmetric=True, finite=False, dtype=np.float32):
+        k = []
+        for d in range(len(shape)):
+            kd = np.fft.fftfreq(shape[d])
+            kd *= 2 * np.pi
+            kdshape = np.ones(len(shape), dtype='int')
+            if symmetric and d == len(shape) - 1:
+                kd = kd[:shape[d] // 2 + 1]
+            kdshape[d] = len(kd)
+            kd = kd.reshape(kdshape)
+
+            k.append(kd.astype(dtype))
+        del kd, kdshape
+        return k
+
+    def multi_gpu_impl(shape,
+                       symmetric=True,
+                       finite=False,
+                       dtype=np.float32,
+                       base_sharding=None):
+
+        assert (isinstance(base_sharding, NamedSharding))
+        kvec = FFTKOperator.single_gpu_impl(shape, symmetric, finite, dtype)
+
+        z_sharding, y_sharding = FFTKOperator.shardings_to_use_in_impl(base_sharding)
+
+        return [
+            jax.make_array_from_callback(
+                (shape[0], 1, 1),
+                sharding=z_sharding,
+                data_callback=lambda x: kvec[0].reshape([-1, 1, 1])[x]),
+            jax.make_array_from_callback(
+                (1, shape[1], 1),
+                sharding=y_sharding,
+                data_callback=lambda x: kvec[1].reshape([1, -1, 1])[x]),
+            kvec[2].reshape([1, 1, -1])
+        ]
+
+    @staticmethod
+    def shardings_to_use_in_impl(base_sharding):
+        spec = base_sharding.spec
+
+        z_sharding = NamedSharding(base_sharding.mesh, P(spec[0], None, None))
+        y_sharding = NamedSharding(base_sharding.mesh, P(None, spec[1], None))
+
+        return z_sharding, y_sharding
+
+
+class GenerateParticlesOperator(CallBackOperator):
+
+    name = 'generate_initial_positions'
+
+    def single_gpu_impl(shape):
+        return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in shape]),
+                         axis=-1)
+
+    def multi_gpu_impl(shape, base_sharding=None):
+        assert (isinstance(base_sharding, NamedSharding))
+        sharding = GenerateParticlesOperator.shardings_to_use_in_impl(
+            base_sharding)
+
+        return jax.make_array_from_callback(
+            shape=tuple([*shape, 3]),
+            sharding=sharding,
+            data_callback=lambda x: jnp.stack(jnp.meshgrid(
+                jnp.arange(shape[0])[x[0]],
+                jnp.arange(shape[1])[x[1]],
+                jnp.arange(shape[2]),
+                indexing='ij'),
+                                              axis=-1))
+
+    def shardings_to_use_in_impl(base_sharding):
+        return base_sharding
+
+
+register_operator(FFTOperator)
+register_operator(IFFTOperator)
+register_operator(HaloExchangeOperator)
+register_operator(PaddingOperator)
+register_operator(UnpaddingOperator)
+register_operator(NormalFieldOperator)
+register_operator(FFTKOperator)
+register_operator(GenerateParticlesOperator)
diff --git a/jaxpm/_src/painting_ops.py b/jaxpm/_src/painting_ops.py
new file mode 100644
index 0000000..7556d37
--- /dev/null
+++ b/jaxpm/_src/painting_ops.py
@@ -0,0 +1,559 @@
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from jax import lax
+from jax.lax import scan
+from jaxdecomp import halo_exchange
+
+from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
+                                    ShardedOperator, register_operator)
+from jaxpm.ops import slice_pad, slice_unpad
+
+
+class CICPaintOperator(ShardedOperator):
+
+    name = 'cic_paint'
+
+    def single_gpu_impl(particle_mesh: jnp.ndarray,
+                        positions: jnp.ndarray,
+                        halo_size=0):
+
+        del halo_size
+
+        positions = positions.reshape([-1, 3])
+
+        positions = jnp.expand_dims(positions, 1)
+        floor = jnp.floor(positions)
+
+        connection = 
jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], + [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords_mod = jnp.mod( + neighboor_coords.reshape([-1, 8, 3]).astype('int32'), + jnp.array(particle_mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod, + kernel.reshape([-1, 8]), dnums) + + return particle_mesh + + def multi_gpu_impl(particle_mesh: jnp.ndarray, + positions: jnp.ndarray, + halo_size=8, + __aux_input=None): + + rank = jax.process_index() + correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0]) + correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1]) + # Get positions relative to the start of each slice + positions = positions.at[:, :, :, 1].add(correct_y) + positions = positions.at[:, :, :, 0].add(correct_z) + positions = positions.reshape([-1, 3]) + + halo_tuple = (halo_size, halo_size) + if __aux_input[0] == 1: + halo_width = ((0, 0), halo_tuple, (0, 0)) + halo_start = [0, halo_size, 0] + elif __aux_input[1] == 1: + halo_width = (halo_tuple, (0, 0), (0, 0)) + halo_start = [halo_size, 0, 0] + else: + halo_width = (halo_tuple, halo_tuple, (0, 0)) + halo_start = [halo_size, halo_size, 0] + + particle_mesh = jnp.pad(particle_mesh, halo_width) + positions += jnp.array(halo_start).reshape([-1, 3]) + + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], + [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. 
- jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords_mod = jnp.mod( + neighboor_coords.reshape([-1, 8, 3]).astype('int32'), + jnp.array(particle_mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod, + kernel.reshape([-1, 8]), dnums) + + return particle_mesh, halo_size + + def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None): + + if __aux_input[0] == 1: + halo_width = (0, halo_size, 0) + halo_extents = (0, halo_size // 2, 0) + elif __aux_input[1] == 1: + halo_width = (halo_size, 0, 0) + halo_extents = (halo_size // 2, 0, 0) + else: + halo_width = (halo_size, halo_size, 0) + halo_extents = (halo_size // 2, halo_size // 2, 0) + + particle_mesh = halo_exchange(particle_mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width) + + return particle_mesh + + def get_aux_input_from_base_sharding(base_sharding): + + def get_axis_size(sharding, index): + axis_name = sharding.spec[index] + if axis_name == None: + return 1 + else: + return sharding.mesh.shape[sharding.spec[index]] + + return [get_axis_size(base_sharding, i) for i in range(2)] + + +class CICReadOperator(ShardedOperator): + + name = 'cic_read' + + def single_gpu_impl(particle_mesh: jnp.ndarray, + positions: jnp.ndarray, + halo_size=0): + del halo_size + + original_shape = positions.shape + positions = positions.reshape([-1, 3]) + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], + [1., 1, 0], [1., 0, 1], [0., 1, 1], + [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. 
- jnp.abs(positions - neighboor_coords)
+        kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
+
+        neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
+                                   jnp.array(particle_mesh.shape))
+
+        particles = (
+            particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
+                          neighboor_coords[..., 2]] * kernel).sum(axis=-1)
+        return particles.reshape(original_shape[:-1])
+
+    def multi_gpu_prolog(particle_mesh: jnp.ndarray,
+                         positions: jnp.ndarray,
+                         halo_size=0,
+                         __aux_input=None):
+
+        halo_tuple = (halo_size, halo_size)
+        if __aux_input[0] == 1:
+            halo_width = ((0, 0), halo_tuple, (0, 0))
+            halo_extents = (0, halo_size // 2, 0)
+        elif __aux_input[1] == 1:
+            halo_width = (halo_tuple, (0, 0), (0, 0))
+            halo_extents = (halo_size // 2, 0, 0)
+        else:
+            halo_width = (halo_tuple, halo_tuple, (0, 0))
+            halo_extents = (halo_size // 2, halo_size // 2, 0)
+
+        particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
+        particle_mesh = halo_exchange(particle_mesh,
+                                      halo_extents=halo_extents,
+                                      halo_periods=(True, True, True))
+
+        return particle_mesh, positions, halo_size
+
+    def multi_gpu_impl(particle_mesh: jnp.ndarray,
+                       positions: jnp.ndarray,
+                       halo_size=0,
+                       __aux_input=None):
+
+        original_shape = positions.shape
+        positions = positions.reshape([-1, 3])
+        if __aux_input[0] == 1:
+            halo_start = [0, halo_size, 0]
+        elif __aux_input[1] == 1:
+            halo_start = [halo_size, 0, 0]
+        else:
+            halo_start = [halo_size, halo_size, 0]
+
+        positions += jnp.array(halo_start).reshape([-1, 3])
+
+        positions = jnp.expand_dims(positions, 1)
+        floor = jnp.floor(positions)
+        connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
+                                 [1., 1, 0], [1., 0, 1], [0., 1, 1],
+                                 [1., 1, 1]]])
+
+        neighboor_coords = floor + connection
+        kernel = 1. - jnp.abs(positions - neighboor_coords)
+        kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
+
+        neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
+                                   jnp.array(particle_mesh.shape))
+
+        particles = (
+            particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
+                          neighboor_coords[..., 2]] * kernel).sum(axis=-1)
+        return particles.reshape(original_shape[:-1])
+
+
+def _chunk_split(ptcl_num, chunk_size, *arrays):
+    """Split and reshape particle arrays into chunks and remainders, with the remainders
+    preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
+    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
+    remainder_size = ptcl_num % chunk_size
+    chunk_num = ptcl_num // chunk_size
+
+    remainder = None
+    chunks = arrays
+    if remainder_size:
+        remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
+        chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
+
+    # `scan` triggers errors in scatter and gather without the `full`
+    chunks = [
+        x.reshape(chunk_num, chunk_size, *x.shape[1:])
+        if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
+    ]
+
+    return remainder, chunks
+
+
+def enmesh(i1, d1, a1, s1, b12, a2, s2):
+    """Multilinear enmeshing."""
+    i1 = jnp.asarray(i1)
+    d1 = jnp.asarray(d1)
+    a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
+    if s1 is not None:
+        s1 = jnp.array(s1, dtype=i1.dtype)
+    b12 = jnp.float64(b12)
+    if a2 is not None:
+        a2 = jnp.float64(a2)
+    if s2 is not None:
+        s2 = jnp.array(s2, dtype=i1.dtype)
+
+    dim = i1.shape[1]
+    neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
+                 jnp.arange(dim, dtype=i1.dtype)) & 1
+
+    if a2 is not None:
+        P = i1 * a1 + d1 - b12
+        P = P[:, jnp.newaxis]  # insert neighbor axis
+        i2 = P + neighbors * a2  # multilinear
+
+        if s1 is not None:
+            L = s1 * a1
+            i2 %= L
+
+        i2 //= a2
+        d2 = P - i2 * a2
+
+        if s1 is not None:
+            d2 -= jnp.rint(d2 / L) * L  # also abs(d2) < a2 is expected
+
+        i2 = i2.astype(i1.dtype)
+        d2 = d2.astype(d1.dtype)
+        a2 = a2.astype(d1.dtype)
+
+        d2 /= a2
+    else:
+        i12, d12 = jnp.divmod(b12, a1)
+        i1 -= i12.astype(i1.dtype)
+        d1 -= d12.astype(d1.dtype)
+
+        # insert neighbor axis
+        i1 = i1[:, jnp.newaxis]
+        d1 = d1[:, jnp.newaxis]
+
+        # multilinear
+        d1 /= a1
+        i2 = jnp.floor(d1).astype(i1.dtype)
+        i2 += neighbors
+        d2 = d1 - i2
+        i2 += i1
+
+        if s1 is not None:
+            i2 %= s1
+
+    f2 = 1 - jnp.abs(d2)
+
+    if s1 is None and s2 is not None:  # all i2 >= 0 if s1 is not None
+        i2 = jnp.where(i2 < 0, s2, i2)
+
+    f2 = f2.prod(axis=-1)
+
+    return i2, f2
+
+
+def _scatter_chunk(carry, chunk):
+    mesh_shape, mesh, offset, cell_size = carry
+    pmid, disp, val = chunk
+    spatial_ndim = pmid.shape[1]
+    spatial_shape = mesh.shape
+
+    # multilinear mesh indices and fractions
+    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
+                       spatial_shape)
+    # scatter
+    ind = tuple(ind[..., i] for i in range(spatial_ndim))
+    mesh = mesh.at[ind].add(val * frac)
+
+    carry = mesh_shape, mesh, offset, cell_size
+    return carry, None
+
+
+def scatter(pmid,
+            disp,
+            mesh,
+            chunk_size=2**24,
+            val=1.,
+            offset=0,
+            cell_size=1.):
+
+    ptcl_num, spatial_ndim = pmid.shape
+    val = jnp.asarray(val)
+    mesh = jnp.asarray(mesh)
+
+    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
+    carry = mesh.shape, mesh, offset, cell_size
+    if remainder is not None:
+        carry = _scatter_chunk(carry, remainder)[0]
+    carry = scan(_scatter_chunk, carry, chunks)[0]
+    mesh = carry[1]
+    return mesh
+
+
+def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
+    """Gather particle values from mesh multilinearly in n-D.
+
+    Parameters
+    ----------
+    pmid : ArrayLike
+    disp : ArrayLike
+    mesh : ArrayLike
+        Input mesh.
+    val : ArrayLike, optional
+        Input values, can be 0D.
+    offset : ArrayLike, optional
+        Offset of mesh to particle grid. If 0D, the value is used in each dimension.
+    cell_size : float, optional
+        Mesh cell size in [L]. Default is 1.
+
+    Returns
+    -------
+    val : jax.Array
+        Output values.
+
+    """
+    return _gather(pmid, disp, mesh, chunk_size, val, offset, cell_size)
+
+
+def _chunk_cat(remainder_array, chunked_array):
+    """Reshape and concatenate one remainder and one chunked particle arrays."""
+    array = chunked_array.reshape(-1, *chunked_array.shape[2:])
+
+    if remainder_array is not None:
+        array = jnp.concatenate((remainder_array, array), axis=0)
+
+    return array
+
+
+def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
+    ptcl_num, spatial_ndim = pmid.shape
+
+    mesh = jnp.asarray(mesh)
+
+    val = jnp.asarray(val)
+
+    if mesh.shape[spatial_ndim:] != val.shape[1:]:
+        raise ValueError('channel shape mismatch: '
+                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
+
+    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp,
+                                     val)
+
+    carry = mesh.shape, mesh, offset, cell_size
+    val_0 = None
+    if remainder is not None:
+        val_0 = _gather_chunk(carry, remainder)[1]
+    val = scan(_gather_chunk, carry, chunks)[1]
+
+    val = _chunk_cat(val_0, val)
+
+    return val
+
+
+def _gather_chunk(carry, chunk):
+    mesh_shape, mesh, offset, cell_size = carry
+    pmid, disp, val = chunk
+
+    spatial_ndim = pmid.shape[1]
+
+    spatial_shape = mesh.shape[:spatial_ndim]
+    chan_ndim = mesh.ndim - spatial_ndim
+    chan_axis = tuple(range(-chan_ndim, 0))
+
+    # multilinear mesh indices and fractions
+    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
+                       spatial_shape)
+
+    # gather
+    ind = tuple(ind[..., i] for i in range(spatial_ndim))
+    frac = jnp.expand_dims(frac, chan_axis)
+    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
+
+    return carry, val
+
+
+class CICPaintDXOperator(ShardedOperator):
+
+    name = 'cic_paint_dx'
+
+    def single_gpu_impl(displacement, halo_size=0):
+
+        del halo_size
+
+        original_shape = displacement.shape
+
+        particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
+
+        a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
+                               jnp.arange(particle_mesh.shape[1]),
+                               jnp.arange(particle_mesh.shape[2]),
+                               indexing='ij')
+
+        pmid = jnp.stack([b, a, c], axis=-1)
+        pmid = pmid.reshape([-1, 3])
+        return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
+
+    def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
+
+        original_shape = displacement.shape
+        particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
+
+        halo_tuple = (halo_size, halo_size)
+        if __aux_input[0] == 1:
+            halo_width = ((0, 0), halo_tuple, (0, 0))
+        elif __aux_input[1] == 1:
+            halo_width = (halo_tuple, (0, 0), (0, 0))
+        else:
+            halo_width = (halo_tuple, halo_tuple, (0, 0))
+
+        particle_mesh = jnp.pad(particle_mesh, halo_width)
+
+        a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
+                               jnp.arange(particle_mesh.shape[1]),
+                               jnp.arange(particle_mesh.shape[2]),
+                               indexing='ij')
+
+        pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
+        pmid = pmid.reshape([-1, 3])
+        return scatter(pmid, displacement.reshape([-1, 3]),
+                       particle_mesh), halo_size
+
+    def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
+
+        if __aux_input[0] == 1:
+            halo_width = (0, halo_size, 0)
+            halo_extents = (0, halo_size // 2, 0)
+        elif __aux_input[1] == 1:
+            halo_width = (halo_size, 0, 0)
+            halo_extents = (halo_size // 2, 0, 0)
+        else:
+            halo_width = (halo_size, halo_size, 0)
+            halo_extents = (halo_size // 2, halo_size // 2, 0)
+
+        particle_mesh = halo_exchange(particle_mesh,
+                                      halo_extents=halo_extents,
+                                      halo_periods=(True, True, True))
+        particle_mesh = slice_unpad(particle_mesh, 
pad_width=halo_width) + + return particle_mesh + + +class CICReadDXOperator(ShardedOperator): + + name = 'cic_read_dx' + + def single_gpu_impl(particle_mesh, halo_size=0): + + del halo_size + + original_shape = (*particle_mesh.shape, 3) + + a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), + jnp.arange(particle_mesh.shape[1]), + jnp.arange(particle_mesh.shape[2]), + indexing='ij') + + pmid = jnp.stack([b, a, c], axis=-1) + pmid = pmid.reshape([-1, 3]) + return _gather(pmid, jnp.zeros_like(pmid), particle_mesh) + + def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None): + + halo_tuple = (halo_size, halo_size) + if __aux_input[0] == 1: + halo_width = ((0, 0), halo_tuple, (0, 0)) + halo_extents = (0, halo_size // 2, 0) + elif __aux_input[1] == 1: + halo_width = (halo_tuple, (0, 0), (0, 0)) + halo_extents = (halo_size // 2, 0, 0) + else: + halo_width = (halo_tuple, halo_tuple, (0, 0)) + halo_extents = (halo_size // 2, halo_size // 2, 0) + + particle_mesh = slice_pad(particle_mesh, pad_width=halo_width) + particle_mesh = halo_exchange(particle_mesh, + halo_extents=halo_extents, + halo_periods=(True, True, True)) + + return particle_mesh, halo_size + + + def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None): + + original_shape = (*particle_mesh.shape, 3) + halo_tuple = (halo_size, halo_size) + if __aux_input[0] == 1: + halo_width = ((0, 0), halo_tuple, (0, 0)) + elif __aux_input[1] == 1: + halo_width = (halo_tuple, (0, 0), (0, 0)) + else: + halo_width = (halo_tuple, halo_tuple, (0, 0)) + + particle_mesh = jnp.pad(particle_mesh, halo_width) + + a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), + jnp.arange(particle_mesh.shape[1]), + jnp.arange(particle_mesh.shape[2]), + indexing='ij') + + pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1) + pmid = pmid.reshape([-1, 3]) + # TODO must be reshaped + return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size + + +register_operator(CICPaintOperator) +register_operator(CICReadOperator) +register_operator(CICPaintDXOperator)
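
The single-GPU implementations above are plain functions on their classes, so they can be exercised directly once the two new modules are in place. A minimal sketch, assuming this patch is applied and JAX is available; the mesh size and PRNG key are illustrative only:

    import jax
    import jax.numpy as jnp

    from jaxpm._src.painting_ops import CICPaintOperator, CICReadOperator

    mesh_shape = (32, 32, 32)
    key = jax.random.PRNGKey(0)

    # Particle positions in mesh units, one particle per cell, shape (nx, ny, nz, 3).
    positions = jax.random.uniform(key, (*mesh_shape, 3)) * jnp.array(mesh_shape)

    # Paint the particles onto an empty mesh with the cloud-in-cell kernel,
    # then interpolate the painted field back at the same positions.
    density = CICPaintOperator.single_gpu_impl(jnp.zeros(mesh_shape), positions)
    values = CICReadOperator.single_gpu_impl(density, positions)

    print(density.sum())  # ~32**3: CIC painting conserves the total particle count
    print(values.shape)   # (32, 32, 32): one interpolated value per particle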
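
The halo bookkeeping in `UnpaddingOperator.multi_gpu_impl` (fold the overlap regions back into the interior, then crop the halos away) can also be checked in isolation on one process. A small sketch under the same assumption that the patch is applied; the slab size and halo width are arbitrary:

    import jax.numpy as jnp

    from jaxpm._src.base_ops import UnpaddingOperator

    halo = 4
    pad_width = ((halo, halo), (halo, halo), (0, 0))

    # Stand-in for a local slab after painting plus halo exchange: every cell,
    # including the halo cells written by neighbouring slabs, holds one unit.
    padded = jnp.ones((8 + 2 * halo, 8 + 2 * halo, 4))

    unpadded = UnpaddingOperator.multi_gpu_impl(padded, pad_width)
    print(unpadded.shape)     # (8, 8, 4): the halo regions are cropped away
    print(unpadded[4, 4, 0])  # 1.0: deep interior cells are untouched
    print(unpadded[0, 4, 0])  # 2.0: edge cells absorb the folded halo contribution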