mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Implement Ops
This commit is contained in:
parent
e708f5b176
commit
bc2612a198
2 changed files with 785 additions and 0 deletions
226
jaxpm/_src/base_ops.py
Normal file
226
jaxpm/_src/base_ops.py
Normal file
|
@ -0,0 +1,226 @@
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from inspect import signature
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jaxdecomp
|
||||||
|
import numpy as np
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from jaxdecomp import halo_exchange
|
||||||
|
from jaxdecomp.fft import pfft3d, pifft3d
|
||||||
|
|
||||||
|
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
||||||
|
ShardedOperator, register_operator)
|
||||||
|
|
||||||
|
|
||||||
|
class FFTOperator(CustomPartionedOperator):
|
||||||
|
|
||||||
|
name = 'fftn'
|
||||||
|
|
||||||
|
def single_gpu_impl(x):
|
||||||
|
return jnp.fft.fftn(x)
|
||||||
|
|
||||||
|
def multi_gpu_impl(x):
|
||||||
|
return pfft3d(x)
|
||||||
|
|
||||||
|
|
||||||
|
class IFFTOperator(CustomPartionedOperator):
|
||||||
|
|
||||||
|
name = 'ifftn'
|
||||||
|
|
||||||
|
def single_gpu_impl(x):
|
||||||
|
return jnp.fft.ifftn(x)
|
||||||
|
|
||||||
|
def multi_gpu_impl(x):
|
||||||
|
return pifft3d(x)
|
||||||
|
|
||||||
|
|
||||||
|
class HaloExchangeOperator(CustomPartionedOperator):
|
||||||
|
|
||||||
|
name = 'halo_exchange'
|
||||||
|
|
||||||
|
# Halo exchange does nothing on a single GPU
|
||||||
|
# Inside a jit , this will be optimized out
|
||||||
|
def single_gpu_impl(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def multi_gpu_impl(x):
|
||||||
|
return halo_exchange(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Padding and unpadding operators should not do anything in case of single GPU
|
||||||
|
# Since there is no halo exchange for a single GPU
|
||||||
|
# Inside a jit , this will be optimized out
|
||||||
|
|
||||||
|
|
||||||
|
class PaddingOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'slice_pad'
|
||||||
|
|
||||||
|
def single_gpu_impl(x, pad_width):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def multi_gpu_impl(x, pad_width):
|
||||||
|
return jnp.pad(x, pad_width)
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_spec = base_sharding, P()
|
||||||
|
out_spec = base_sharding
|
||||||
|
|
||||||
|
return in_spec, out_spec
|
||||||
|
|
||||||
|
|
||||||
|
class UnpaddingOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'slice_unpad'
|
||||||
|
|
||||||
|
def single_gpu_impl(x, pad_width):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def multi_gpu_impl(x, pad_width):
|
||||||
|
|
||||||
|
# WARNING : unequal halo size is not supported
|
||||||
|
halo_x, _ = pad_width[0]
|
||||||
|
halo_y, _ = pad_width[0]
|
||||||
|
|
||||||
|
# Apply corrections along x
|
||||||
|
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
|
||||||
|
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
|
||||||
|
# Apply corrections along y
|
||||||
|
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||||
|
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||||
|
|
||||||
|
return x[halo_x:-halo_x, halo_y:-halo_y, :]
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_spec = base_sharding, P()
|
||||||
|
out_spec = base_sharding
|
||||||
|
|
||||||
|
return in_spec, out_spec
|
||||||
|
|
||||||
|
|
||||||
|
class NormalFieldOperator(CallBackOperator):
|
||||||
|
|
||||||
|
name = 'normal'
|
||||||
|
|
||||||
|
def single_gpu_impl(shape, key, dtype='float32'):
|
||||||
|
return jax.random.normal(key, shape, dtype=dtype)
|
||||||
|
|
||||||
|
def multi_gpu_impl(shape, key, dtype='float32', base_sharding=None):
|
||||||
|
|
||||||
|
assert (isinstance(base_sharding, NamedSharding))
|
||||||
|
sharding = NormalFieldOperator.shardings_to_use_in_impl(base_sharding)
|
||||||
|
|
||||||
|
def get_axis_size(sharding, index):
|
||||||
|
axis_name = sharding.spec[index]
|
||||||
|
if axis_name == None:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sharding.mesh.shape[sharding.spec[index]]
|
||||||
|
|
||||||
|
pdims = [get_axis_size(sharding, i) for i in range(2)]
|
||||||
|
local_mesh_shape = [
|
||||||
|
shape[0] // pdims[1], shape[1] // pdims[0], shape[2]
|
||||||
|
]
|
||||||
|
|
||||||
|
return jax.make_array_from_single_device_arrays(
|
||||||
|
shape=shape,
|
||||||
|
sharding=sharding,
|
||||||
|
arrays=[jax.random.normal(key, local_mesh_shape, dtype=dtype)])
|
||||||
|
|
||||||
|
def shardings_to_use_in_impl(base_sharding):
|
||||||
|
return base_sharding
|
||||||
|
|
||||||
|
|
||||||
|
class FFTKOperator(CallBackOperator):
|
||||||
|
name = 'fftk'
|
||||||
|
|
||||||
|
def single_gpu_impl(shape, symmetric=True, finite=False, dtype=np.float32):
|
||||||
|
k = []
|
||||||
|
for d in range(len(shape)):
|
||||||
|
kd = np.fft.fftfreq(shape[d])
|
||||||
|
kd *= 2 * np.pi
|
||||||
|
kdshape = np.ones(len(shape), dtype='int')
|
||||||
|
if symmetric and d == len(shape) - 1:
|
||||||
|
kd = kd[:shape[d] // 2 + 1]
|
||||||
|
kdshape[d] = len(kd)
|
||||||
|
kd = kd.reshape(kdshape)
|
||||||
|
|
||||||
|
k.append(kd.astype(dtype))
|
||||||
|
del kd, kdshape
|
||||||
|
return k
|
||||||
|
|
||||||
|
def multi_gpu_impl(shape,
|
||||||
|
symmetric=True,
|
||||||
|
finite=False,
|
||||||
|
dtype=np.float32,
|
||||||
|
base_sharding=None):
|
||||||
|
|
||||||
|
assert (isinstance(base_sharding, NamedSharding))
|
||||||
|
kvec = FFTKOperator.single_gpu_impl(shape, symmetric, finite, dtype)
|
||||||
|
|
||||||
|
z_sharding, y_sharding = FFTKOperator.shardings_to_use_in_impl(shape)
|
||||||
|
|
||||||
|
return [
|
||||||
|
jax.make_array_from_callback(
|
||||||
|
(shape[0], 1, 1),
|
||||||
|
sharding=z_sharding,
|
||||||
|
data_callback=lambda x: kvec[0].reshape([-1, 1, 1])[x]),
|
||||||
|
jax.make_array_from_callback(
|
||||||
|
(1, shape[1], 1),
|
||||||
|
sharding=y_sharding,
|
||||||
|
data_callback=lambda x: kvec[1].reshape([1, -1, 1])[x]),
|
||||||
|
kvec[2].reshape([1, 1, -1])
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def shardings_to_use_in_impl(base_sharding):
|
||||||
|
spec = base_sharding.spec
|
||||||
|
|
||||||
|
z_sharding = NamedSharding(P(spec[0], None, None))
|
||||||
|
y_sharding = NamedSharding(P(None, spec[1], None))
|
||||||
|
|
||||||
|
return z_sharding, y_sharding
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateParticlesOperator(CallBackOperator):
|
||||||
|
|
||||||
|
name = 'generate_initial_positions'
|
||||||
|
|
||||||
|
def single_gpu_impl(shape):
|
||||||
|
return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in shape]),
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
def multi_gpu_impl(shape, base_sharding=None):
|
||||||
|
assert (isinstance(base_sharding, NamedSharding))
|
||||||
|
sharding = GenerateParticlesOperator.shardings_to_use_in_impl(
|
||||||
|
base_sharding)
|
||||||
|
|
||||||
|
return jax.make_array_from_callback(
|
||||||
|
shape=tuple([*shape, 3]),
|
||||||
|
sharding=sharding,
|
||||||
|
data_callback=lambda x: jnp.stack(jnp.meshgrid(
|
||||||
|
jnp.arange(shape[0])[x[0]],
|
||||||
|
jnp.arange(shape[1])[x[1]],
|
||||||
|
jnp.arange(shape[2]),
|
||||||
|
indexing='ij'),
|
||||||
|
axis=-1))
|
||||||
|
|
||||||
|
def shardings_to_use_in_impl(base_sharding):
|
||||||
|
return base_sharding
|
||||||
|
|
||||||
|
|
||||||
|
register_operator(FFTOperator)
|
||||||
|
register_operator(IFFTOperator)
|
||||||
|
register_operator(HaloExchangeOperator)
|
||||||
|
register_operator(PaddingOperator)
|
||||||
|
register_operator(UnpaddingOperator)
|
||||||
|
register_operator(NormalFieldOperator)
|
||||||
|
register_operator(FFTKOperator)
|
||||||
|
register_operator(GenerateParticlesOperator)
|
559
jaxpm/_src/painting_ops.py
Normal file
559
jaxpm/_src/painting_ops.py
Normal file
|
@ -0,0 +1,559 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import lax
|
||||||
|
from jax.lax import scan
|
||||||
|
from jaxdecomp import halo_exchange
|
||||||
|
|
||||||
|
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
||||||
|
ShardedOperator, register_operator)
|
||||||
|
from jaxpm.ops import slice_pad, slice_unpad
|
||||||
|
|
||||||
|
|
||||||
|
class CICPaintOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'cic_paint'
|
||||||
|
|
||||||
|
def single_gpu_impl(particle_mesh: jnp.ndarray,
|
||||||
|
positions: jnp.ndarray,
|
||||||
|
halo_size=0):
|
||||||
|
|
||||||
|
del halo_size
|
||||||
|
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
|
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||||
|
[1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords_mod = jnp.mod(
|
||||||
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
|
jnp.array(particle_mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
|
update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||||
|
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
|
||||||
|
kernel.reshape([-1, 8]), dnums)
|
||||||
|
|
||||||
|
return particle_mesh
|
||||||
|
|
||||||
|
def multi_gpu_impl(particle_mesh: jnp.ndarray,
|
||||||
|
positions: jnp.ndarray,
|
||||||
|
halo_size=8,
|
||||||
|
__aux_input=None):
|
||||||
|
|
||||||
|
rank = jax.process_index()
|
||||||
|
correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0])
|
||||||
|
correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1])
|
||||||
|
# Get positions relative to the start of each slice
|
||||||
|
positions = positions.at[:, :, :, 1].add(correct_y)
|
||||||
|
positions = positions.at[:, :, :, 0].add(correct_z)
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
|
|
||||||
|
halo_tuple = (halo_size, halo_size)
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||||
|
halo_start = [0, halo_size, 0]
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||||
|
halo_start = [halo_size, 0, 0]
|
||||||
|
else:
|
||||||
|
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||||
|
halo_start = [halo_size, halo_size, 0]
|
||||||
|
|
||||||
|
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||||
|
positions += jnp.array(halo_start).reshape([-1, 3])
|
||||||
|
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||||
|
[1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords_mod = jnp.mod(
|
||||||
|
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||||
|
jnp.array(particle_mesh.shape))
|
||||||
|
|
||||||
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
|
update_window_dims=(),
|
||||||
|
inserted_window_dims=(0, 1, 2),
|
||||||
|
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||||
|
particle_mesh = lax.scatter_add(particle_mesh, neighboor_coords_mod,
|
||||||
|
kernel.reshape([-1, 8]), dnums)
|
||||||
|
|
||||||
|
return particle_mesh, halo_size
|
||||||
|
|
||||||
|
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
|
||||||
|
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = (0, halo_size, 0)
|
||||||
|
halo_extents = (0, halo_size // 2, 0)
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_size, 0, 0)
|
||||||
|
halo_extents = (halo_size // 2, 0, 0)
|
||||||
|
else:
|
||||||
|
halo_width = (halo_size, halo_size, 0)
|
||||||
|
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||||
|
|
||||||
|
particle_mesh = halo_exchange(particle_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True, True))
|
||||||
|
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
|
||||||
|
|
||||||
|
return particle_mesh
|
||||||
|
|
||||||
|
def get_aux_input_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
def get_axis_size(sharding, index):
|
||||||
|
axis_name = sharding.spec[index]
|
||||||
|
if axis_name == None:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sharding.mesh.shape[sharding.spec[index]]
|
||||||
|
|
||||||
|
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||||
|
|
||||||
|
|
||||||
|
class CICReadOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'cic_read'
|
||||||
|
|
||||||
|
def single_gpu_impl(particle_mesh: jnp.ndarray,
|
||||||
|
positions: jnp.ndarray,
|
||||||
|
halo_size=0):
|
||||||
|
del halo_size
|
||||||
|
|
||||||
|
original_shape = positions.shape
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||||
|
[1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||||
|
jnp.array(particle_mesh.shape))
|
||||||
|
|
||||||
|
particles = (
|
||||||
|
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||||
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||||
|
return particles.reshape(original_shape)
|
||||||
|
|
||||||
|
def multi_gpu_prolog(particle_mesh: jnp.ndarray,
|
||||||
|
positions: jnp.ndarray,
|
||||||
|
halo_size=0,
|
||||||
|
__aux_input=None):
|
||||||
|
|
||||||
|
halo_tuple = (halo_size, halo_size)
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||||
|
halo_extents = (0, halo_size // 2, 0)
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||||
|
halo_extents = (halo_size // 2, 0, 0)
|
||||||
|
else:
|
||||||
|
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||||
|
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||||
|
|
||||||
|
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
|
||||||
|
particle_mesh = halo_exchange(particle_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True, True))
|
||||||
|
|
||||||
|
return particle_mesh, positions, halo_size
|
||||||
|
|
||||||
|
def multi_gpu_impl(particle_mesh: jnp.ndarray,
|
||||||
|
positions: jnp.ndarray,
|
||||||
|
halo_size=0,
|
||||||
|
__aux_input=None):
|
||||||
|
|
||||||
|
original_shape = positions.shape
|
||||||
|
positions = positions.reshape([-1, 3])
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_start = [0, halo_size, 0]
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_start = [halo_size, 0, 0]
|
||||||
|
else:
|
||||||
|
halo_start = [halo_size, halo_size, 0]
|
||||||
|
|
||||||
|
positions += jnp.array(halo_start).reshape([-1, 3])
|
||||||
|
|
||||||
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
floor = jnp.floor(positions)
|
||||||
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||||
|
[1., 1, 0], [1., 0, 1], [0., 1, 1],
|
||||||
|
[1., 1, 1]]])
|
||||||
|
|
||||||
|
neighboor_coords = floor + connection
|
||||||
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
|
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
|
||||||
|
jnp.array(particle_mesh.shape))
|
||||||
|
|
||||||
|
particles = (
|
||||||
|
particle_mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||||
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||||
|
return particles.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
|
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||||
|
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
|
||||||
|
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
|
||||||
|
remainder_size = ptcl_num % chunk_size
|
||||||
|
chunk_num = ptcl_num // chunk_size
|
||||||
|
|
||||||
|
remainder = None
|
||||||
|
chunks = arrays
|
||||||
|
if remainder_size:
|
||||||
|
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
|
||||||
|
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
|
||||||
|
|
||||||
|
# `scan` triggers errors in scatter and gather without the `full`
|
||||||
|
chunks = [
|
||||||
|
x.reshape(chunk_num, chunk_size, *x.shape[1:])
|
||||||
|
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
|
||||||
|
]
|
||||||
|
|
||||||
|
return remainder, chunks
|
||||||
|
|
||||||
|
|
||||||
|
def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
||||||
|
"""Multilinear enmeshing."""
|
||||||
|
i1 = jnp.asarray(i1)
|
||||||
|
d1 = jnp.asarray(d1)
|
||||||
|
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
|
||||||
|
if s1 is not None:
|
||||||
|
s1 = jnp.array(s1, dtype=i1.dtype)
|
||||||
|
b12 = jnp.float64(b12)
|
||||||
|
if a2 is not None:
|
||||||
|
a2 = jnp.float64(a2)
|
||||||
|
if s2 is not None:
|
||||||
|
s2 = jnp.array(s2, dtype=i1.dtype)
|
||||||
|
|
||||||
|
dim = i1.shape[1]
|
||||||
|
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
|
||||||
|
jnp.arange(dim, dtype=i1.dtype)) & 1
|
||||||
|
|
||||||
|
if a2 is not None:
|
||||||
|
P = i1 * a1 + d1 - b12
|
||||||
|
P = P[:, jnp.newaxis] # insert neighbor axis
|
||||||
|
i2 = P + neighbors * a2 # multilinear
|
||||||
|
|
||||||
|
if s1 is not None:
|
||||||
|
L = s1 * a1
|
||||||
|
i2 %= L
|
||||||
|
|
||||||
|
i2 //= a2
|
||||||
|
d2 = P - i2 * a2
|
||||||
|
|
||||||
|
if s1 is not None:
|
||||||
|
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
|
||||||
|
|
||||||
|
i2 = i2.astype(i1.dtype)
|
||||||
|
d2 = d2.astype(d1.dtype)
|
||||||
|
a2 = a2.astype(d1.dtype)
|
||||||
|
|
||||||
|
d2 /= a2
|
||||||
|
else:
|
||||||
|
i12, d12 = jnp.divmod(b12, a1)
|
||||||
|
i1 -= i12.astype(i1.dtype)
|
||||||
|
d1 -= d12.astype(d1.dtype)
|
||||||
|
|
||||||
|
# insert neighbor axis
|
||||||
|
i1 = i1[:, jnp.newaxis]
|
||||||
|
d1 = d1[:, jnp.newaxis]
|
||||||
|
|
||||||
|
# multilinear
|
||||||
|
d1 /= a1
|
||||||
|
i2 = jnp.floor(d1).astype(i1.dtype)
|
||||||
|
i2 += neighbors
|
||||||
|
d2 = d1 - i2
|
||||||
|
i2 += i1
|
||||||
|
|
||||||
|
if s1 is not None:
|
||||||
|
i2 %= s1
|
||||||
|
|
||||||
|
f2 = 1 - jnp.abs(d2)
|
||||||
|
|
||||||
|
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
|
||||||
|
i2 = jnp.where(i2 < 0, s2, i2)
|
||||||
|
|
||||||
|
f2 = f2.prod(axis=-1)
|
||||||
|
|
||||||
|
return i2, f2
|
||||||
|
|
||||||
|
|
||||||
|
def _scatter_chunk(carry, chunk):
|
||||||
|
mesh_shape , mesh, offset, cell_size = carry
|
||||||
|
pmid, disp, val = chunk
|
||||||
|
spatial_ndim = pmid.shape[1]
|
||||||
|
spatial_shape = mesh.shape
|
||||||
|
|
||||||
|
# multilinear mesh indices and fractions
|
||||||
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
|
spatial_shape)
|
||||||
|
# scatter
|
||||||
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
|
mesh = mesh.at[ind].add(val * frac)
|
||||||
|
|
||||||
|
carry = mesh, offset, cell_size
|
||||||
|
return carry, None
|
||||||
|
|
||||||
|
|
||||||
|
def scatter(pmid,
|
||||||
|
disp,
|
||||||
|
mesh,
|
||||||
|
chunk_size=2**24,
|
||||||
|
val=1.,
|
||||||
|
offset=0,
|
||||||
|
cell_size=1.):
|
||||||
|
|
||||||
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
val = jnp.asarray(val)
|
||||||
|
mesh = jnp.asarray(mesh)
|
||||||
|
|
||||||
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
|
carry = mesh.shape , mesh, offset, cell_size
|
||||||
|
if remainder is not None:
|
||||||
|
carry = _scatter_chunk(carry, remainder)[0]
|
||||||
|
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||||
|
mesh = carry[0]
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
|
||||||
|
"""Gather particle values from mesh multilinearly in n-D.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ptcl : Particles
|
||||||
|
conf : Configuration
|
||||||
|
mesh : ArrayLike
|
||||||
|
Input mesh.
|
||||||
|
val : ArrayLike, optional
|
||||||
|
Input values, can be 0D.
|
||||||
|
offset : ArrayLike, optional
|
||||||
|
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
|
||||||
|
cell_size : float, optional
|
||||||
|
Mesh cell size in [L]. Default is ``conf.cell_size``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
val : jax.Array
|
||||||
|
Output values.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return _gather(ptcl.pmid, ptcl.disp, conf, mesh, val, offset, cell_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_cat(remainder_array, chunked_array):
|
||||||
|
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
||||||
|
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||||
|
|
||||||
|
if remainder_array is not None:
|
||||||
|
array = jnp.concatenate((remainder_array, array), axis=0)
|
||||||
|
|
||||||
|
return array
|
||||||
|
|
||||||
|
|
||||||
|
def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=None):
|
||||||
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
|
mesh = jnp.asarray(mesh)
|
||||||
|
|
||||||
|
val = jnp.asarray(val)
|
||||||
|
|
||||||
|
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||||
|
raise ValueError('channel shape mismatch: '
|
||||||
|
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||||
|
|
||||||
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp,
|
||||||
|
val)
|
||||||
|
|
||||||
|
carry = mesh.shape , mesh, offset, cell_size
|
||||||
|
val_0 = None
|
||||||
|
if remainder is not None:
|
||||||
|
val_0 = _gather_chunk(carry, remainder)[1]
|
||||||
|
val = scan(_gather_chunk, carry, chunks)[1]
|
||||||
|
|
||||||
|
val = _chunk_cat(val_0, val)
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_chunk(carry, chunk):
|
||||||
|
mesh_shape , mesh, offset, cell_size = carry
|
||||||
|
pmid, disp, val = chunk
|
||||||
|
|
||||||
|
spatial_ndim = pmid.shape[1]
|
||||||
|
|
||||||
|
spatial_shape = mesh.shape[:spatial_ndim]
|
||||||
|
chan_ndim = mesh.ndim - spatial_ndim
|
||||||
|
chan_axis = tuple(range(-chan_ndim, 0))
|
||||||
|
|
||||||
|
# multilinear mesh indices and fractions
|
||||||
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset,
|
||||||
|
cell_size, spatial_shape, False)
|
||||||
|
|
||||||
|
# gather
|
||||||
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
|
frac = jnp.expand_dims(frac, chan_axis)
|
||||||
|
val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)
|
||||||
|
|
||||||
|
return carry, val
|
||||||
|
|
||||||
|
|
||||||
|
class CICPaintDXOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'cic_paint_dx'
|
||||||
|
|
||||||
|
def single_gpu_impl(displacement, halo_size=0):
|
||||||
|
|
||||||
|
del halo_size
|
||||||
|
|
||||||
|
original_shape = displacement.shape
|
||||||
|
|
||||||
|
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||||
|
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
pmid = jnp.stack([b, a, c], axis=-1)
|
||||||
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
|
||||||
|
|
||||||
|
def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
|
||||||
|
|
||||||
|
original_shape = displacement.shape
|
||||||
|
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||||
|
|
||||||
|
halo_tuple = (halo_size, halo_size)
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||||
|
else:
|
||||||
|
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||||
|
|
||||||
|
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||||
|
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
||||||
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
return scatter(pmid, displacement.reshape([-1, 3]),
|
||||||
|
particle_mesh), halo_size
|
||||||
|
|
||||||
|
def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
|
||||||
|
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = (0, halo_size, 0)
|
||||||
|
halo_extents = (0, halo_size // 2, 0)
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_size, 0, 0)
|
||||||
|
halo_extents = (halo_size // 2, 0, 0)
|
||||||
|
else:
|
||||||
|
halo_width = (halo_size, halo_size, 0)
|
||||||
|
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||||
|
|
||||||
|
particle_mesh = halo_exchange(particle_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True, True))
|
||||||
|
particle_mesh = slice_unpad(particle_mesh, pad_width=halo_width)
|
||||||
|
|
||||||
|
return particle_mesh
|
||||||
|
|
||||||
|
|
||||||
|
class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
|
name = 'cic_read_dx'
|
||||||
|
|
||||||
|
def single_gpu_impl(particle_mesh, halo_size=0):
|
||||||
|
|
||||||
|
del halo_size
|
||||||
|
|
||||||
|
original_shape = (*particle_mesh.shape, 3)
|
||||||
|
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
pmid = jnp.stack([b, a, c], axis=-1)
|
||||||
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
|
||||||
|
|
||||||
|
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
|
||||||
|
|
||||||
|
halo_tuple = (halo_size, halo_size)
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||||
|
halo_extents = (0, halo_size // 2, 0)
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||||
|
halo_extents = (halo_size // 2, 0, 0)
|
||||||
|
else:
|
||||||
|
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||||
|
halo_extents = (halo_size // 2, halo_size // 2, 0)
|
||||||
|
|
||||||
|
particle_mesh = slice_pad(particle_mesh, pad_width=halo_width)
|
||||||
|
particle_mesh = halo_exchange(particle_mesh,
|
||||||
|
halo_extents=halo_extents,
|
||||||
|
halo_periods=(True, True, True))
|
||||||
|
|
||||||
|
return particle_mesh, halo_size
|
||||||
|
|
||||||
|
|
||||||
|
def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
|
||||||
|
|
||||||
|
original_shape = (*particle_mesh.shape, 3)
|
||||||
|
halo_tuple = (halo_size, halo_size)
|
||||||
|
if __aux_input[0] == 1:
|
||||||
|
halo_width = ((0, 0), halo_tuple, (0, 0))
|
||||||
|
elif __aux_input[1] == 1:
|
||||||
|
halo_width = (halo_tuple, (0, 0), (0, 0))
|
||||||
|
else:
|
||||||
|
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
||||||
|
|
||||||
|
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
||||||
|
|
||||||
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
|
indexing='ij')
|
||||||
|
|
||||||
|
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
||||||
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
# TODO must be reshaped
|
||||||
|
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size
|
||||||
|
|
||||||
|
|
||||||
|
register_operator(CICPaintOperator)
|
||||||
|
register_operator(CICReadOperator)
|
||||||
|
register_operator(CICPaintDXOperator)
|
Loading…
Add table
Reference in a new issue