Adjust painting operators

This commit is contained in:
Wassim KABALAN 2024-07-09 02:34:49 +02:00
parent 6d8f130be7
commit d2fb1ee1e2
2 changed files with 96 additions and 59 deletions

View file

@ -4,6 +4,8 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import lax from jax import lax
from jax.lax import scan from jax.lax import scan
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxdecomp import halo_exchange from jaxdecomp import halo_exchange
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator, from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
@ -128,6 +130,13 @@ class CICPaintOperator(ShardedOperator):
return [get_axis_size(base_sharding, i) for i in range(2)] return [get_axis_size(base_sharding, i) for i in range(2)]
def infer_sharding_from_base_sharding(base_sharding):
in_specs = base_sharding.spec, base_sharding.spec, P()
out_specs = base_sharding.spec
return in_specs, out_specs
class CICReadOperator(ShardedOperator): class CICReadOperator(ShardedOperator):
@ -215,6 +224,24 @@ class CICReadOperator(ShardedOperator):
neighboor_coords[..., 3]] * kernel).sum(axis=-1) neighboor_coords[..., 3]] * kernel).sum(axis=-1)
return particles.reshape(original_shape) return particles.reshape(original_shape)
def get_aux_input_from_base_sharding(base_sharding):
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name == None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
return [get_axis_size(base_sharding, i) for i in range(2)]
def infer_sharding_from_base_sharding(base_sharding):
in_specs = base_sharding.spec, base_sharding.spec, P()
out_specs = base_sharding.spec
return in_specs, out_specs
def _chunk_split(ptcl_num, chunk_size, *arrays): def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders """Split and reshape particle arrays into chunks and remainders, with the remainders
@ -305,7 +332,7 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2):
def _scatter_chunk(carry, chunk): def _scatter_chunk(carry, chunk):
mesh_shape , mesh, offset, cell_size = carry mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk pmid, disp, val = chunk
spatial_ndim = pmid.shape[1] spatial_ndim = pmid.shape[1]
spatial_shape = mesh.shape spatial_shape = mesh.shape
@ -317,7 +344,7 @@ def _scatter_chunk(carry, chunk):
ind = tuple(ind[..., i] for i in range(spatial_ndim)) ind = tuple(ind[..., i] for i in range(spatial_ndim))
mesh = mesh.at[ind].add(val * frac) mesh = mesh.at[ind].add(val * frac)
carry = mesh, offset, cell_size carry = mesh, offset, cell_size, mesh_shape
return carry, None return carry, None
@ -334,7 +361,7 @@ def scatter(pmid,
mesh = jnp.asarray(mesh) mesh = jnp.asarray(mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh.shape , mesh, offset, cell_size carry = mesh, offset, cell_size, mesh.shape
if remainder is not None: if remainder is not None:
carry = _scatter_chunk(carry, remainder)[0] carry = _scatter_chunk(carry, remainder)[0]
carry = scan(_scatter_chunk, carry, chunks)[0] carry = scan(_scatter_chunk, carry, chunks)[0]
@ -342,31 +369,6 @@ def scatter(pmid,
return mesh return mesh
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
"""Gather particle values from mesh multilinearly in n-D.
Parameters
----------
ptcl : Particles
conf : Configuration
mesh : ArrayLike
Input mesh.
val : ArrayLike, optional
Input values, can be 0D.
offset : ArrayLike, optional
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
cell_size : float, optional
Mesh cell size in [L]. Default is ``conf.cell_size``.
Returns
-------
val : jax.Array
Output values.
"""
return _gather(ptcl.pmid, ptcl.disp, conf, mesh, val, offset, cell_size)
def _chunk_cat(remainder_array, chunked_array): def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays.""" """Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:]) array = chunked_array.reshape(-1, *chunked_array.shape[2:])
@ -377,7 +379,7 @@ def _chunk_cat(remainder_array, chunked_array):
return array return array
def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=None): def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh) mesh = jnp.asarray(mesh)
@ -388,10 +390,9 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non
raise ValueError('channel shape mismatch: ' raise ValueError('channel shape mismatch: '
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}') f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
val)
carry = mesh.shape , mesh, offset, cell_size carry = mesh, offset, cell_size, mesh.shape
val_0 = None val_0 = None
if remainder is not None: if remainder is not None:
val_0 = _gather_chunk(carry, remainder)[1] val_0 = _gather_chunk(carry, remainder)[1]
@ -403,7 +404,7 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non
def _gather_chunk(carry, chunk): def _gather_chunk(carry, chunk):
mesh_shape , mesh, offset, cell_size = carry mesh, offset, cell_size, mesh_shape = carry
pmid, disp, val = chunk pmid, disp, val = chunk
spatial_ndim = pmid.shape[1] spatial_ndim = pmid.shape[1]
@ -413,8 +414,8 @@ def _gather_chunk(carry, chunk):
chan_axis = tuple(range(-chan_ndim, 0)) chan_axis = tuple(range(-chan_ndim, 0))
# multilinear mesh indices and fractions # multilinear mesh indices and fractions
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
cell_size, spatial_shape, False) spatial_shape)
# gather # gather
ind = tuple(ind[..., i] for i in range(spatial_ndim)) ind = tuple(ind[..., i] for i in range(spatial_ndim))
@ -441,7 +442,7 @@ class CICPaintDXOperator(ShardedOperator):
jnp.arange(particle_mesh.shape[2]), jnp.arange(particle_mesh.shape[2]),
indexing='ij') indexing='ij')
pmid = jnp.stack([b, a, c], axis=-1) pmid = jnp.stack([a, b, c], axis=-1)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh) return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
@ -489,6 +490,24 @@ class CICPaintDXOperator(ShardedOperator):
return particle_mesh return particle_mesh
def get_aux_input_from_base_sharding(base_sharding):
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name == None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
return [get_axis_size(base_sharding, i) for i in range(2)]
def infer_sharding_from_base_sharding(base_sharding):
in_specs = base_sharding.spec, P()
out_specs = base_sharding.spec
return in_specs, out_specs
class CICReadDXOperator(ShardedOperator): class CICReadDXOperator(ShardedOperator):
@ -498,7 +517,7 @@ class CICReadDXOperator(ShardedOperator):
del halo_size del halo_size
original_shape = (*particle_mesh.shape, 3) original_shape = particle_mesh.shape
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]), jnp.arange(particle_mesh.shape[1]),
@ -507,8 +526,10 @@ class CICReadDXOperator(ShardedOperator):
pmid = jnp.stack([b, a, c], axis=-1) pmid = jnp.stack([b, a, c], axis=-1)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh) positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
return positions.reshape(original_shape)
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None): def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
halo_tuple = (halo_size, halo_size) halo_tuple = (halo_size, halo_size)
@ -528,20 +549,10 @@ class CICReadDXOperator(ShardedOperator):
halo_periods=(True, True, True)) halo_periods=(True, True, True))
return particle_mesh, halo_size return particle_mesh, halo_size
def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
original_shape = (*particle_mesh.shape, 3) def multi_gpu_impl(particle_mesh, halo_size, __aux_input=None):
halo_tuple = (halo_size, halo_size)
if __aux_input[0] == 1:
halo_width = ((0, 0), halo_tuple, (0, 0))
elif __aux_input[1] == 1:
halo_width = (halo_tuple, (0, 0), (0, 0))
else:
halo_width = (halo_tuple, halo_tuple, (0, 0))
particle_mesh = jnp.pad(particle_mesh, halo_width) original_shape = particle_mesh.shape
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]), jnp.arange(particle_mesh.shape[1]),
@ -550,10 +561,30 @@ class CICReadDXOperator(ShardedOperator):
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1) pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
# TODO must be reshaped positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size
return positions.reshape(original_shape)
def get_aux_input_from_base_sharding(base_sharding):
def get_axis_size(sharding, index):
axis_name = sharding.spec[index]
if axis_name == None:
return 1
else:
return sharding.mesh.shape[sharding.spec[index]]
return [get_axis_size(base_sharding, i) for i in range(2)]
def infer_sharding_from_base_sharding(base_sharding):
in_specs = base_sharding.spec, P()
out_specs = base_sharding.spec
return in_specs, out_specs
register_operator(CICPaintOperator) register_operator(CICPaintOperator)
register_operator(CICReadOperator) register_operator(CICReadOperator)
register_operator(CICPaintDXOperator) register_operator(CICPaintDXOperator)
register_operator(CICReadDXOperator)

View file

@ -4,22 +4,28 @@ import jax.numpy as jnp
import jaxpm import jaxpm
import jaxpm.ops import jaxpm.ops
from jaxpm.kernels import cic_compensation, fftk from jaxpm._src.spmd_config import pm_operators
from jaxpm.kernels import cic_compensation
def cic_paint(particle_mesh, positions, halo_size=0): def cic_paint(particle_mesh, positions, halo_size=0):
return jaxpm.ops.cic_paint(particle_mesh, positions, halo_size=halo_size) return pm_operators.cic_paint(particle_mesh,
positions,
halo_size=halo_size)
def cic_read(mesh, positions, halo_size=0): def cic_read(mesh, positions, halo_size=0):
return jaxpm.ops.cic_read(mesh, positions, halo_size=halo_size) return pm_operators.cic_read(mesh, positions, halo_size=halo_size)
def cic_paint_dx(displacements, halo_size=0): def cic_paint_dx(displacements, halo_size=0):
return jaxpm.ops.cic_paint_dx(displacements, halo_size=halo_size) return pm_operators.cic_paint_dx(displacements, halo_size=halo_size)
def cic_read_dx(particle_mesh, halo_size=0):
return pm_operators.cic_read_dx(particle_mesh, halo_size=halo_size)
# TO REDO
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh """ Paints positions onto a 2d mesh
mesh: [nx, ny] mesh: [nx, ny]