mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
Implement Ops
This commit is contained in:
parent
e708f5b176
commit
bc2612a198
2 changed files with 785 additions and 0 deletions
559
jaxpm/_src/painting_ops.py
Normal file
559
jaxpm/_src/painting_ops.py
Normal file
|
@ -0,0 +1,559 @@
|
|||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.lax import scan
|
||||
from jaxdecomp import halo_exchange
|
||||
|
||||
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
||||
ShardedOperator, register_operator)
|
||||
from jaxpm.ops import slice_pad, slice_unpad
|
||||
|
||||
|
||||
class CICPaintOperator(ShardedOperator):
|
||||
|
||||
name = 'cic_paint'
|
||||
|
||||
def single_gpu_impl(particle_mesh: jnp.ndarray,
|
||||
positions: jnp.ndarray,
|
||||
halo_size=0):
|
||||
|
||||
del halo_size
|
||||
|
||||
positions = positions.reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||
[1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
neighboor_coords_mod = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
jnp.array(particle_mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
|
||||
return particle_mesh
|
||||
|
||||
def multi_gpu_impl(particle_mesh: jnp.ndarray,
|
||||
positions: jnp.ndarray,
|
||||
halo_size=8,
|
||||
__aux_input=None):
|
||||
|
||||
rank = jax.process_index()
|
||||
correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0])
|
||||
correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1])
|
||||
# Get positions relative to the start of each slice
|
||||
positions = positions.at[:, :, :, 1].add(correct_y)
|
||||
positions = positions.at[:, :, :, 0].add(correct_z)
|
||||
positions = positions.reshape([-1, 3])
|
||||
|
||||
halo_tuple = (halo_size, halo_size)
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||
halo_start = [0, halo_size, 0]
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||
halo_start = [halo_size, 0, 0]
|
||||
else:
|
||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||
halo_start = [halo_size, halo_size, 0]
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||
positions += jnp.array(halo_start).reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||
[1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
neighboor_coords_mod = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
jnp.array(particle_mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
|
||||
return particle_mesh, halo_size
|
||||
|
||||
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
|
||||
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = (0, halo_size, 0)
|
||||
halo_extents = (0, halo_size // 2, 0)
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_size, 0, 0)
|
||||
halo_extents = (halo_size // 2, 0, 0)
|
||||
else:
|
||||
halo_width = (halo_size, halo_size, 0)
|
||||
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||
|
||||
particle_mesh = halo_exchange(particle_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
|
||||
|
||||
return particle_mesh
|
||||
|
||||
def get_aux_input_from_base_sharding(base_sharding):
|
||||
|
||||
def get_axis_size(sharding, index):
|
||||
axis_name = sharding.spec[index]
|
||||
if axis_name == None:
|
||||
return 1
|
||||
else:
|
||||
return sharding.mesh.shape[sharding.spec[index]]
|
||||
|
||||
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||
|
||||
|
||||
class CICReadOperator(ShardedOperator):
|
||||
|
||||
name = 'cic_read'
|
||||
|
||||
def single_gpu_impl(particle_mesh: jnp.ndarray,
|
||||
positions: jnp.ndarray,
|
||||
halo_size=0):
|
||||
del halo_size
|
||||
|
||||
original_shape = positions.shape
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||
[1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||
jnp.array(particle_mesh.shape))
|
||||
|
||||
particles = (
|
||||
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||
return particles.reshape(original_shape)
|
||||
|
||||
def multi_gpu_prolog(particle_mesh: jnp.ndarray,
|
||||
positions: jnp.ndarray,
|
||||
halo_size=0,
|
||||
__aux_input=None):
|
||||
|
||||
halo_tuple = (halo_size, halo_size)
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||
halo_extents = (0, halo_size // 2, 0)
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||
halo_extents = (halo_size // 2, 0, 0)
|
||||
else:
|
||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||
|
||||
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
|
||||
particle_mesh = halo_exchange(particle_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
|
||||
return particle_mesh, positions, halo_size
|
||||
|
||||
def multi_gpu_impl(particle_mesh: jnp.ndarray,
|
||||
positions: jnp.ndarray,
|
||||
halo_size=0,
|
||||
__aux_input=None):
|
||||
|
||||
original_shape = positions.shape
|
||||
positions = positions.reshape([-1, 3])
|
||||
if __aux_input[0] == 1:
|
||||
halo_start = [0, halo_size, 0]
|
||||
elif __aux_input[1] == 1:
|
||||
halo_start = [halo_size, 0, 0]
|
||||
else:
|
||||
halo_start = [halo_size, halo_size, 0]
|
||||
|
||||
positions += jnp.array(halo_start).reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||
[1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||
jnp.array(particle_mesh.shape))
|
||||
|
||||
particles = (
|
||||
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||
return particles.reshape(original_shape)
|
||||
|
||||
|
||||
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
|
||||
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
|
||||
remainder_size = ptcl_num % chunk_size
|
||||
chunk_num = ptcl_num // chunk_size
|
||||
|
||||
remainder = None
|
||||
chunks = arrays
|
||||
if remainder_size:
|
||||
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
|
||||
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
|
||||
|
||||
# `scan` triggers errors in scatter and gather without the `full`
|
||||
chunks = [
|
||||
x.reshape(chunk_num, chunk_size, *x.shape[1:])
|
||||
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
|
||||
]
|
||||
|
||||
return remainder, chunks
|
||||
|
||||
|
||||
def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
||||
"""Multilinear enmeshing."""
|
||||
i1 = jnp.asarray(i1)
|
||||
d1 = jnp.asarray(d1)
|
||||
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
|
||||
if s1 is not None:
|
||||
s1 = jnp.array(s1, dtype=i1.dtype)
|
||||
b12 = jnp.float64(b12)
|
||||
if a2 is not None:
|
||||
a2 = jnp.float64(a2)
|
||||
if s2 is not None:
|
||||
s2 = jnp.array(s2, dtype=i1.dtype)
|
||||
|
||||
dim = i1.shape[1]
|
||||
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
|
||||
jnp.arange(dim, dtype=i1.dtype)) & 1
|
||||
|
||||
if a2 is not None:
|
||||
P = i1 * a1 + d1 - b12
|
||||
P = P[:, jnp.newaxis] # insert neighbor axis
|
||||
i2 = P + neighbors * a2 # multilinear
|
||||
|
||||
if s1 is not None:
|
||||
L = s1 * a1
|
||||
i2 %= L
|
||||
|
||||
i2 //= a2
|
||||
d2 = P - i2 * a2
|
||||
|
||||
if s1 is not None:
|
||||
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
|
||||
|
||||
i2 = i2.astype(i1.dtype)
|
||||
d2 = d2.astype(d1.dtype)
|
||||
a2 = a2.astype(d1.dtype)
|
||||
|
||||
d2 /= a2
|
||||
else:
|
||||
i12, d12 = jnp.divmod(b12, a1)
|
||||
i1 -= i12.astype(i1.dtype)
|
||||
d1 -= d12.astype(d1.dtype)
|
||||
|
||||
# insert neighbor axis
|
||||
i1 = i1[:, jnp.newaxis]
|
||||
d1 = d1[:, jnp.newaxis]
|
||||
|
||||
# multilinear
|
||||
d1 /= a1
|
||||
i2 = jnp.floor(d1).astype(i1.dtype)
|
||||
i2 += neighbors
|
||||
d2 = d1 - i2
|
||||
i2 += i1
|
||||
|
||||
if s1 is not None:
|
||||
i2 %= s1
|
||||
|
||||
f2 = 1 - jnp.abs(d2)
|
||||
|
||||
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
|
||||
i2 = jnp.where(i2 < 0, s2, i2)
|
||||
|
||||
f2 = f2.prod(axis=-1)
|
||||
|
||||
return i2, f2
|
||||
|
||||
|
||||
def _scatter_chunk(carry, chunk):
|
||||
mesh_shape , mesh, offset, cell_size = carry
|
||||
pmid, disp, val = chunk
|
||||
spatial_ndim = pmid.shape[1]
|
||||
spatial_shape = mesh.shape
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
# scatter
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
mesh = mesh.at[ind].add(val * frac)
|
||||
|
||||
carry = mesh, offset, cell_size
|
||||
return carry, None
|
||||
|
||||
|
||||
def scatter(pmid,
|
||||
disp,
|
||||
mesh,
|
||||
chunk_size=2**24,
|
||||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jnp.asarray(val)
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh.shape , mesh, offset, cell_size
|
||||
if remainder is not None:
|
||||
carry = _scatter_chunk(carry, remainder)[0]
|
||||
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||
mesh = carry[0]
|
||||
return mesh
|
||||
|
||||
|
||||
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
|
||||
"""Gather particle values from mesh multilinearly in n-D.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ptcl : Particles
|
||||
conf : Configuration
|
||||
mesh : ArrayLike
|
||||
Input mesh.
|
||||
val : ArrayLike, optional
|
||||
Input values, can be 0D.
|
||||
offset : ArrayLike, optional
|
||||
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
|
||||
cell_size : float, optional
|
||||
Mesh cell size in [L]. Default is ``conf.cell_size``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
val : jax.Array
|
||||
Output values.
|
||||
|
||||
"""
|
||||
return _gather(ptcl.pmid, ptcl.disp, conf, mesh, val, offset, cell_size)
|
||||
|
||||
|
||||
def _chunk_cat(remainder_array, chunked_array):
|
||||
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
||||
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||
|
||||
if remainder_array is not None:
|
||||
array = jnp.concatenate((remainder_array, array), axis=0)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=None):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
val = jnp.asarray(val)
|
||||
|
||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||
raise ValueError('channel shape mismatch: '
|
||||
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp,
|
||||
val)
|
||||
|
||||
carry = mesh.shape , mesh, offset, cell_size
|
||||
val_0 = None
|
||||
if remainder is not None:
|
||||
val_0 = _gather_chunk(carry, remainder)[1]
|
||||
val = scan(_gather_chunk, carry, chunks)[1]
|
||||
|
||||
val = _chunk_cat(val_0, val)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
def _gather_chunk(carry, chunk):
|
||||
mesh_shape , mesh, offset, cell_size = carry
|
||||
pmid, disp, val = chunk
|
||||
|
||||
spatial_ndim = pmid.shape[1]
|
||||
|
||||
spatial_shape = mesh.shape[:spatial_ndim]
|
||||
chan_ndim = mesh.ndim - spatial_ndim
|
||||
chan_axis = tuple(range(-chan_ndim, 0))
|
||||
|
||||
# multilinear mesh indices and fractions
|
||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset,
|
||||
cell_size, spatial_shape, False)
|
||||
|
||||
# gather
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
frac = jnp.expand_dims(frac, chan_axis)
|
||||
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
||||
|
||||
return carry, val
|
||||
|
||||
|
||||
class CICPaintDXOperator(ShardedOperator):
|
||||
|
||||
name = 'cic_paint_dx'
|
||||
|
||||
def single_gpu_impl(displacement, halo_size=0):
|
||||
|
||||
del halo_size
|
||||
|
||||
original_shape = displacement.shape
|
||||
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([b, a, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
|
||||
|
||||
def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
|
||||
|
||||
original_shape = displacement.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
|
||||
halo_tuple = (halo_size, halo_size)
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||
else:
|
||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
return scatter(pmid, displacement.reshape([-1, 3]),
|
||||
particle_mesh), halo_size
|
||||
|
||||
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
|
||||
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = (0, halo_size, 0)
|
||||
halo_extents = (0, halo_size // 2, 0)
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_size, 0, 0)
|
||||
halo_extents = (halo_size // 2, 0, 0)
|
||||
else:
|
||||
halo_width = (halo_size, halo_size, 0)
|
||||
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||
|
||||
particle_mesh = halo_exchange(particle_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
|
||||
|
||||
return particle_mesh
|
||||
|
||||
|
||||
class CICReadDXOperator(ShardedOperator):
|
||||
|
||||
name = 'cic_read_dx'
|
||||
|
||||
def single_gpu_impl(particle_mesh, halo_size=0):
|
||||
|
||||
del halo_size
|
||||
|
||||
original_shape = (*particle_mesh.shape, 3)
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([b, a, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
|
||||
|
||||
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
|
||||
|
||||
halo_tuple = (halo_size, halo_size)
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||
halo_extents = (0, halo_size // 2, 0)
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||
halo_extents = (halo_size // 2, 0, 0)
|
||||
else:
|
||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||
|
||||
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
|
||||
particle_mesh = halo_exchange(particle_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
|
||||
return particle_mesh, halo_size
|
||||
|
||||
|
||||
def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
|
||||
|
||||
original_shape = (*particle_mesh.shape, 3)
|
||||
halo_tuple = (halo_size, halo_size)
|
||||
if __aux_input[0] == 1:
|
||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||
elif __aux_input[1] == 1:
|
||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||
else:
|
||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
# TODO must be reshaped
|
||||
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size
|
||||
|
||||
|
||||
register_operator(CICPaintOperator)
|
||||
register_operator(CICReadOperator)
|
||||
register_operator(CICPaintDXOperator)
|
Loading…
Add table
Add a link
Reference in a new issue