From d2fb1ee1e2ff828e98933417b5cf41efaebcfde5 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 9 Jul 2024 02:34:49 +0200 Subject: [PATCH] Adjust painting operators --- jaxpm/_src/painting_ops.py | 139 +++++++++++++++++++++++-------------- jaxpm/painting.py | 16 +++-- 2 files changed, 96 insertions(+), 59 deletions(-) diff --git a/jaxpm/_src/painting_ops.py b/jaxpm/_src/painting_ops.py index 7556d37..af3c441 100644 --- a/jaxpm/_src/painting_ops.py +++ b/jaxpm/_src/painting_ops.py @@ -4,6 +4,8 @@ import jax import jax.numpy as jnp from jax import lax from jax.lax import scan +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P from jaxdecomp import halo_exchange from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator, @@ -128,6 +130,13 @@ class CICPaintOperator(ShardedOperator): return [get_axis_size(base_sharding, i) for i in range(2)] + def infer_sharding_from_base_sharding(base_sharding): + + in_specs = base_sharding.spec, base_sharding.spec, P() + out_specs = base_sharding.spec + + return in_specs, out_specs + class CICReadOperator(ShardedOperator): @@ -215,6 +224,24 @@ class CICReadOperator(ShardedOperator): neighboor_coords[..., 3]] * kernel).sum(axis=-1) return particles.reshape(original_shape) + def get_aux_input_from_base_sharding(base_sharding): + + def get_axis_size(sharding, index): + axis_name = sharding.spec[index] + if axis_name == None: + return 1 + else: + return sharding.mesh.shape[sharding.spec[index]] + + return [get_axis_size(base_sharding, i) for i in range(2)] + + def infer_sharding_from_base_sharding(base_sharding): + + in_specs = base_sharding.spec, base_sharding.spec, P() + out_specs = base_sharding.spec + + return in_specs, out_specs + def _chunk_split(ptcl_num, chunk_size, *arrays): """Split and reshape particle arrays into chunks and remainders, with the remainders @@ -305,7 +332,7 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2): def _scatter_chunk(carry, chunk): - mesh_shape , mesh, offset, cell_size = carry + mesh, offset, cell_size, mesh_shape = carry pmid, disp, val = chunk spatial_ndim = pmid.shape[1] spatial_shape = mesh.shape @@ -317,7 +344,7 @@ def _scatter_chunk(carry, chunk): ind = tuple(ind[..., i] for i in range(spatial_ndim)) mesh = mesh.at[ind].add(val * frac) - carry = mesh, offset, cell_size + carry = mesh, offset, cell_size, mesh_shape return carry, None @@ -334,7 +361,7 @@ def scatter(pmid, mesh = jnp.asarray(mesh) remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) - carry = mesh.shape , mesh, offset, cell_size + carry = mesh, offset, cell_size, mesh.shape if remainder is not None: carry = _scatter_chunk(carry, remainder)[0] carry = scan(_scatter_chunk, carry, chunks)[0] @@ -342,31 +369,6 @@ def scatter(pmid, return mesh -def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None): - """Gather particle values from mesh multilinearly in n-D. - - Parameters - ---------- - ptcl : Particles - conf : Configuration - mesh : ArrayLike - Input mesh. - val : ArrayLike, optional - Input values, can be 0D. - offset : ArrayLike, optional - Offset of mesh to particle grid. If 0D, the value is used in each dimension. - cell_size : float, optional - Mesh cell size in [L]. Default is ``conf.cell_size``. - - Returns - ------- - val : jax.Array - Output values. - - """ - return _gather(ptcl.pmid, ptcl.disp, conf, mesh, val, offset, cell_size) - - def _chunk_cat(remainder_array, chunked_array): """Reshape and concatenate one remainder and one chunked particle arrays.""" array = chunked_array.reshape(-1, *chunked_array.shape[2:]) @@ -377,7 +379,7 @@ def _chunk_cat(remainder_array, chunked_array): return array -def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=None): +def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.): ptcl_num, spatial_ndim = pmid.shape mesh = jnp.asarray(mesh) @@ -388,10 +390,9 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non raise ValueError('channel shape mismatch: ' f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}') - remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, - val) + remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) - carry = mesh.shape , mesh, offset, cell_size + carry = mesh, offset, cell_size, mesh.shape val_0 = None if remainder is not None: val_0 = _gather_chunk(carry, remainder)[1] @@ -403,7 +404,7 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non def _gather_chunk(carry, chunk): - mesh_shape , mesh, offset, cell_size = carry + mesh, offset, cell_size, mesh_shape = carry pmid, disp, val = chunk spatial_ndim = pmid.shape[1] @@ -413,8 +414,8 @@ def _gather_chunk(carry, chunk): chan_axis = tuple(range(-chan_ndim, 0)) # multilinear mesh indices and fractions - ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, - cell_size, spatial_shape, False) + ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, + spatial_shape) # gather ind = tuple(ind[..., i] for i in range(spatial_ndim)) @@ -441,7 +442,7 @@ class CICPaintDXOperator(ShardedOperator): jnp.arange(particle_mesh.shape[2]), indexing='ij') - pmid = jnp.stack([b, a, c], axis=-1) + pmid = jnp.stack([a, b, c], axis=-1) pmid = pmid.reshape([-1, 3]) return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh) @@ -489,6 +490,24 @@ class CICPaintDXOperator(ShardedOperator): return particle_mesh + def get_aux_input_from_base_sharding(base_sharding): + + def get_axis_size(sharding, index): + axis_name = sharding.spec[index] + if axis_name == None: + return 1 + else: + return sharding.mesh.shape[sharding.spec[index]] + + return [get_axis_size(base_sharding, i) for i in range(2)] + + def infer_sharding_from_base_sharding(base_sharding): + + in_specs = base_sharding.spec, P() + out_specs = base_sharding.spec + + return in_specs, out_specs + class CICReadDXOperator(ShardedOperator): @@ -498,7 +517,7 @@ class CICReadDXOperator(ShardedOperator): del halo_size - original_shape = (*particle_mesh.shape, 3) + original_shape = particle_mesh.shape a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), jnp.arange(particle_mesh.shape[1]), @@ -507,8 +526,10 @@ class CICReadDXOperator(ShardedOperator): pmid = jnp.stack([b, a, c], axis=-1) pmid = pmid.reshape([-1, 3]) - return _gather(pmid, jnp.zeros_like(pmid), particle_mesh) - + positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh) + + return positions.reshape(original_shape) + def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None): halo_tuple = (halo_size, halo_size) @@ -528,20 +549,10 @@ class CICReadDXOperator(ShardedOperator): halo_periods=(True, True, True)) return particle_mesh, halo_size - - - def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None): - original_shape = (*particle_mesh.shape, 3) - halo_tuple = (halo_size, halo_size) - if __aux_input[0] == 1: - halo_width = ((0, 0), halo_tuple, (0, 0)) - elif __aux_input[1] == 1: - halo_width = (halo_tuple, (0, 0), (0, 0)) - else: - halo_width = (halo_tuple, halo_tuple, (0, 0)) + def multi_gpu_impl(particle_mesh, halo_size, __aux_input=None): - particle_mesh = jnp.pad(particle_mesh, halo_width) + original_shape = particle_mesh.shape a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), jnp.arange(particle_mesh.shape[1]), @@ -550,10 +561,30 @@ class CICReadDXOperator(ShardedOperator): pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1) pmid = pmid.reshape([-1, 3]) - # TODO must be reshaped - return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size - + positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh) + + return positions.reshape(original_shape) + + def get_aux_input_from_base_sharding(base_sharding): + + def get_axis_size(sharding, index): + axis_name = sharding.spec[index] + if axis_name == None: + return 1 + else: + return sharding.mesh.shape[sharding.spec[index]] + + return [get_axis_size(base_sharding, i) for i in range(2)] + + def infer_sharding_from_base_sharding(base_sharding): + + in_specs = base_sharding.spec, P() + out_specs = base_sharding.spec + + return in_specs, out_specs + register_operator(CICPaintOperator) register_operator(CICReadOperator) register_operator(CICPaintDXOperator) +register_operator(CICReadDXOperator) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 108838c..5631e73 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -4,22 +4,28 @@ import jax.numpy as jnp import jaxpm import jaxpm.ops -from jaxpm.kernels import cic_compensation, fftk +from jaxpm._src.spmd_config import pm_operators +from jaxpm.kernels import cic_compensation def cic_paint(particle_mesh, positions, halo_size=0): - return jaxpm.ops.cic_paint(particle_mesh, positions, halo_size=halo_size) + return pm_operators.cic_paint(particle_mesh, + positions, + halo_size=halo_size) def cic_read(mesh, positions, halo_size=0): - return jaxpm.ops.cic_read(mesh, positions, halo_size=halo_size) + return pm_operators.cic_read(mesh, positions, halo_size=halo_size) def cic_paint_dx(displacements, halo_size=0): - return jaxpm.ops.cic_paint_dx(displacements, halo_size=halo_size) + return pm_operators.cic_paint_dx(displacements, halo_size=halo_size) + + +def cic_read_dx(particle_mesh, halo_size=0): + return pm_operators.cic_read_dx(particle_mesh, halo_size=halo_size) -# TO REDO def cic_paint_2d(mesh, positions, weight): """ Paints positions onto a 2d mesh mesh: [nx, ny]