from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.lax import scan
from jaxdecomp import halo_exchange

from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
                                    ShardedOperator, register_operator)
from jaxpm.ops import slice_pad, slice_unpad


def _halo_layout(halo_size, aux_input):
    """Describe which mesh axes carry a halo for a given device layout.

    ``aux_input`` holds two device counts (as produced by
    ``get_aux_input_from_base_sharding``). If ``aux_input[0] == 1`` only
    mesh axis 1 is padded; otherwise if ``aux_input[1] == 1`` only axis 0 is
    padded; otherwise axes 0 and 1 are padded. Axis 2 is never padded.

    Returns
    -------
    pad_width : tuple of (before, after) pairs for ``jnp.pad`` / ``slice_pad``.
    halo_start : tuple of per-axis halo widths (index of the first interior
        cell); this is also the ``pad_width`` passed to ``slice_unpad``.
    halo_extents : tuple of per-axis extents passed to ``halo_exchange``.
    """
    if aux_input[0] == 1:
        padded = (False, True, False)
    elif aux_input[1] == 1:
        padded = (True, False, False)
    else:
        padded = (True, True, False)
    pad_width = tuple((halo_size, halo_size) if p else (0, 0) for p in padded)
    halo_start = tuple(halo_size if p else 0 for p in padded)
    halo_extents = tuple(halo_size // 2 if p else 0 for p in padded)
    return pad_width, halo_start, halo_extents


def _cic_neighbors(positions, mesh_shape):
    """Return CIC neighbour cells and weights for a batch of positions.

    ``positions`` is flattened to (N, 3) mesh-cell coordinates.

    Returns
    -------
    coords : int32 array of shape (N, 8, 3), the 8 cells surrounding each
        position, wrapped periodically into ``mesh_shape``.
    kernel : array of shape (N, 8), the trilinear (cloud-in-cell) weights.
    """
    positions = jnp.expand_dims(positions.reshape([-1, 3]), 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
    neighbor_coords = floor + connection
    kernel = 1. - jnp.abs(positions - neighbor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
    coords = jnp.mod(neighbor_coords.astype('int32'), jnp.array(mesh_shape))
    return coords, kernel


def _cic_scatter_add(particle_mesh, coords, kernel):
    """Add ``kernel`` weights into ``particle_mesh`` at the (N, 8, 3) ``coords``."""
    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                            inserted_window_dims=(0, 1, 2),
                                            scatter_dims_to_operand_dims=(0, 1,
                                                                          2))
    return lax.scatter_add(particle_mesh, coords, kernel.reshape([-1, 8]),
                           dnums)


def _cic_read(particle_mesh, coords, kernel):
    """Interpolate ``particle_mesh`` at the cells ``coords`` with ``kernel``.

    Returns one value per particle, shape (N,).
    """
    values = particle_mesh[coords[..., 0], coords[..., 1], coords[..., 2]]
    return (values * kernel).sum(axis=-1)


def _grid_pmid(grid_shape):
    """Integer coordinates of every cell of a 3-D grid, as an (N, 3) array.

    Rows follow C order over ``grid_shape``; the row for cell ``(i, j, k)`` is
    ``(j, i, k)`` -- the first two components are swapped.
    NOTE(review): the swap looks deliberate (layout convention) -- confirm.
    """
    a, b, c = jnp.meshgrid(jnp.arange(grid_shape[0]),
                           jnp.arange(grid_shape[1]),
                           jnp.arange(grid_shape[2]),
                           indexing='ij')
    return jnp.stack([b, a, c], axis=-1).reshape([-1, 3])


class CICPaintOperator(ShardedOperator):
    """Cloud-in-cell painting of particle positions onto a periodic 3-D mesh."""

    name = 'cic_paint'

    def single_gpu_impl(particle_mesh: jnp.ndarray,
                        positions: jnp.ndarray,
                        halo_size=0):
        """Deposit unit weight per particle onto ``particle_mesh`` (CIC).

        ``positions`` has a trailing axis of size 3 in mesh-cell units;
        ``halo_size`` is unused on a single device.
        """
        del halo_size
        coords, kernel = _cic_neighbors(positions, particle_mesh.shape)
        return _cic_scatter_add(particle_mesh, coords, kernel)

    def multi_gpu_impl(particle_mesh: jnp.ndarray,
                       positions: jnp.ndarray,
                       halo_size=8,
                       __aux_input=None):
        """Paint local particles onto a halo-padded copy of the local slab.

        Positions are indexed as ``positions[:, :, :, d]`` (4-D) and shifted
        by an origin computed from ``jax.process_index()`` and
        ``__aux_input``. Returns ``(padded_mesh, halo_size)`` for the epilog.
        """
        rank = jax.process_index()
        correct_y = -particle_mesh.shape[1] * (rank // __aux_input[0])
        correct_z = -particle_mesh.shape[0] * (rank % __aux_input[1])
        # Get positions relative to the start of each slice
        positions = positions.at[:, :, :, 1].add(correct_y)
        positions = positions.at[:, :, :, 0].add(correct_z)
        positions = positions.reshape([-1, 3])

        pad_width, halo_start, _ = _halo_layout(halo_size, __aux_input)
        particle_mesh = jnp.pad(particle_mesh, pad_width)
        # Move positions into the coordinate frame of the padded mesh.
        positions += jnp.array(halo_start).reshape([-1, 3])

        coords, kernel = _cic_neighbors(positions, particle_mesh.shape)
        return _cic_scatter_add(particle_mesh, coords, kernel), halo_size

    def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
        """Exchange halos between neighbours, then strip the padding."""
        _, halo_start, halo_extents = _halo_layout(halo_size, __aux_input)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return slice_unpad(particle_mesh, pad_width=halo_start)

    def get_aux_input_from_base_sharding(base_sharding):
        """Return the device counts along the first two axes of the sharding.

        An unsharded axis (``None`` in ``base_sharding.spec``) counts as 1.
        """

        def get_axis_size(sharding, index):
            axis_name = sharding.spec[index]
            if axis_name is None:
                return 1
            return sharding.mesh.shape[axis_name]

        return [get_axis_size(base_sharding, i) for i in range(2)]


# Module-level alias so the helper is reachable both as an operator attribute
# and as a plain function.
get_aux_input_from_base_sharding = (
    CICPaintOperator.get_aux_input_from_base_sharding)


class CICReadOperator(ShardedOperator):
    """Cloud-in-cell interpolation of mesh values at particle positions."""

    name = 'cic_read'

    def single_gpu_impl(particle_mesh: jnp.ndarray,
                        positions: jnp.ndarray,
                        halo_size=0):
        """Read ``particle_mesh`` at ``positions`` (CIC, periodic).

        Returns an array of shape ``positions.shape[:-1]``.
        """
        del halo_size
        original_shape = positions.shape
        coords, kernel = _cic_neighbors(positions, particle_mesh.shape)
        particles = _cic_read(particle_mesh, coords, kernel)
        # One value per particle: drop the trailing coordinate axis.
        return particles.reshape(original_shape[:-1])

    def multi_gpu_prolog(particle_mesh: jnp.ndarray,
                         positions: jnp.ndarray,
                         halo_size=0,
                         __aux_input=None):
        """Pad the local slab and fill its halo from neighbouring devices."""
        pad_width, _, halo_extents = _halo_layout(halo_size, __aux_input)
        particle_mesh = slice_pad(particle_mesh, pad_width=pad_width)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return particle_mesh, positions, halo_size

    def multi_gpu_impl(particle_mesh: jnp.ndarray,
                       positions: jnp.ndarray,
                       halo_size=0,
                       __aux_input=None):
        """Read the halo-padded local mesh at ``positions``.

        Returns an array of shape ``positions.shape[:-1]``.
        """
        original_shape = positions.shape
        positions = positions.reshape([-1, 3])
        _, halo_start, _ = _halo_layout(halo_size, __aux_input)
        # NOTE(review): unlike CICPaintOperator.multi_gpu_impl, positions are
        # not shifted by the local slab origin here -- confirm callers pass
        # slab-local positions.
        positions += jnp.array(halo_start).reshape([-1, 3])
        coords, kernel = _cic_neighbors(positions, particle_mesh.shape)
        particles = _cic_read(particle_mesh, coords, kernel)
        return particles.reshape(original_shape[:-1])


def _chunk_split(ptcl_num, chunk_size, *arrays):
    """Split and reshape particle arrays into chunks and remainders, with the
    remainders preceding the chunks. 0D ones are duplicated as full arrays in
    the chunks."""
    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
    # Avoid a zero divisor when there are no particles.
    chunk_size = max(chunk_size, 1)
    remainder_size = ptcl_num % chunk_size
    chunk_num = ptcl_num // chunk_size

    remainder = None
    chunks = arrays
    if remainder_size:
        remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
        chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]

    # `scan` triggers errors in scatter and gather without the `full`
    chunks = [
        x.reshape(chunk_num, chunk_size, *x.shape[1:])
        if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
    ]

    return remainder, chunks


def enmesh(i1, d1, a1, s1, b12, a2, s2):
    """Multilinear enmeshing.

    Map points given on grid 1 (integer indices ``i1`` plus displacements
    ``d1``) to the ``2**dim`` surrounding cells of grid 2 and compute their
    multilinear weights.

    Parameters
    ----------
    i1 : (N, dim) integer array, indices on grid 1.
    d1 : (N, dim) array, displacements in the same units as ``i1 * a1``.
    a1 : cell size of grid 1.
    s1 : shape of grid 1 used for periodic wrapping, or None.
    b12 : offset of grid 2 relative to grid 1.
    a2 : cell size of grid 2, or None to use grid 1's cell size.
    s2 : shape of grid 2, or None. Only used when ``s1`` is None: negative
        indices are replaced by ``s2`` (i.e. pushed out of bounds).

    Returns
    -------
    i2 : (N, 2**dim, dim) integer indices on grid 2.
    f2 : (N, 2**dim) multilinear weights.
    """
    i1 = jnp.asarray(i1)
    d1 = jnp.asarray(d1)
    a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, dtype=d1.dtype)
    if s1 is not None:
        s1 = jnp.array(s1, dtype=i1.dtype)
    b12 = jnp.float64(b12)
    if a2 is not None:
        a2 = jnp.float64(a2)
    if s2 is not None:
        s2 = jnp.array(s2, dtype=i1.dtype)

    dim = i1.shape[1]
    # Corner offsets of the unit hypercube, shape (2**dim, dim).
    neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
                 jnp.arange(dim, dtype=i1.dtype)) & 1

    if a2 is not None:
        P = i1 * a1 + d1 - b12
        P = P[:, jnp.newaxis]  # insert neighbor axis
        i2 = P + neighbors * a2  # multilinear

        if s1 is not None:
            L = s1 * a1
            i2 %= L

        i2 //= a2
        d2 = P - i2 * a2

        if s1 is not None:
            d2 -= jnp.rint(d2 / L) * L  # also abs(d2) < a2 is expected

        i2 = i2.astype(i1.dtype)
        d2 = d2.astype(d1.dtype)
        a2 = a2.astype(d1.dtype)

        d2 /= a2
    else:
        i12, d12 = jnp.divmod(b12, a1)
        i1 -= i12.astype(i1.dtype)
        d1 -= d12.astype(d1.dtype)

        # insert neighbor axis
        i1 = i1[:, jnp.newaxis]
        d1 = d1[:, jnp.newaxis]

        # multilinear
        d1 /= a1
        i2 = jnp.floor(d1).astype(i1.dtype)
        i2 += neighbors
        d2 = d1 - i2
        i2 += i1

        if s1 is not None:
            i2 %= s1

    f2 = 1 - jnp.abs(d2)

    if s1 is None and s2 is not None:  # all i2 >= 0 if s1 is not None
        i2 = jnp.where(i2 < 0, s2, i2)

    f2 = f2.prod(axis=-1)

    return i2, f2


def _scatter_chunk(carry, chunk):
    """One ``lax.scan`` step of :func:`scatter`.

    ``carry`` is ``(mesh_shape, mesh, offset, cell_size)`` and is returned
    with the same structure (``scan`` requires it) and an updated ``mesh``.
    ``chunk`` is ``(pmid, disp, val)``.
    """
    mesh_shape, mesh, offset, cell_size = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # scatter: broadcast weights over channel axes and per-particle values
    # over the 2**dim neighbour axis.
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = frac.reshape(frac.shape + (1,) * chan_ndim)
    if val.ndim != 0:
        val = jnp.expand_dims(val, 1)
    mesh = mesh.at[ind].add(val * frac)

    carry = mesh_shape, mesh, offset, cell_size
    return carry, None


def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    """Scatter particle values onto ``mesh`` multilinearly in n-D.

    Parameters
    ----------
    pmid : (N, dim) integer array
        Particle grid indices.
    disp : (N, dim) array
        Displacements relative to ``pmid * cell_size``.
    mesh : ArrayLike
        Mesh to add into; its leading ``dim`` axes are spatial and periodic.
    chunk_size : int or None, optional
        Particles processed per ``scan`` step (None: all at once).
    val : ArrayLike, optional
        Per-particle values (leading axis N) or a scalar for every particle.
    offset : ArrayLike, optional
        Offset of the mesh relative to the particle grid.
    cell_size : float, optional
        Cell size shared by the particle grid and the mesh.

    Returns
    -------
    mesh : jax.Array
        The updated mesh.
    """
    ptcl_num, spatial_ndim = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = mesh.shape[:spatial_ndim], mesh, offset, cell_size
    if remainder is not None:
        carry = _scatter_chunk(carry, remainder)[0]
    carry = scan(_scatter_chunk, carry, chunks)[0]

    # carry is (mesh_shape, mesh, offset, cell_size)
    return carry[1]


def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
    """Gather particle values from mesh multilinearly in n-D.

    Parameters
    ----------
    ptcl : Particles
        Object providing ``pmid`` and ``disp``.
    conf : Configuration
        Used for the default ``cell_size``.
    mesh : ArrayLike
        Input mesh.
    val : ArrayLike, optional
        Input values, can be 0D. Gathered values are added to them.
    offset : ArrayLike, optional
        Offset of mesh to particle grid. If 0D, the value is used in each
        dimension.
    cell_size : float, optional
        Mesh cell size in [L]. Default is ``conf.cell_size``.

    Returns
    -------
    val : jax.Array
        Output values.

    """
    if cell_size is None:
        cell_size = conf.cell_size
    return _gather(ptcl.pmid,
                   ptcl.disp,
                   mesh,
                   val=val,
                   offset=offset,
                   cell_size=cell_size)


def _chunk_cat(remainder_array, chunked_array):
    """Reshape and concatenate one remainder and one chunked particle arrays."""
    array = chunked_array.reshape(-1, *chunked_array.shape[2:])
    if remainder_array is not None:
        array = jnp.concatenate((remainder_array, array), axis=0)
    return array


def _gather(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1,
            offset=0,
            cell_size=None):
    """Gather values from ``mesh`` at particles, added onto ``val``.

    ``cell_size=None`` means a cell size of 1 (as in :func:`scatter`).

    Raises
    ------
    ValueError
        If the mesh channel shape differs from ``val.shape[1:]``.
    """
    ptcl_num, spatial_ndim = pmid.shape

    mesh = jnp.asarray(mesh)
    val = jnp.asarray(val)
    if cell_size is None:
        cell_size = 1.

    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = mesh.shape[:spatial_ndim], mesh, offset, cell_size
    val_0 = None
    if remainder is not None:
        val_0 = _gather_chunk(carry, remainder)[1]
    val = scan(_gather_chunk, carry, chunks)[1]

    return _chunk_cat(val_0, val)


def _gather_chunk(carry, chunk):
    """One ``lax.scan`` step of :func:`_gather`; the carry is unchanged."""
    mesh_shape, mesh, offset, cell_size = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # gather
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = frac.reshape(frac.shape + (1,) * chan_ndim)
    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)

    return carry, val


class CICPaintDXOperator(ShardedOperator):
    """Paint one particle per grid cell, displaced by ``displacement``."""

    name = 'cic_paint_dx'

    def single_gpu_impl(displacement, halo_size=0):
        """Scatter unit weight from each displaced grid particle.

        ``displacement`` has shape ``grid_shape + (3,)``; returns a float32
        mesh of shape ``grid_shape``.
        """
        del halo_size
        original_shape = displacement.shape
        particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
        pmid = _grid_pmid(particle_mesh.shape)
        return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)

    def multi_gpu_impl(displacement, halo_size=0, __aux_input=None):
        """Scatter local displaced particles onto a halo-padded mesh.

        Returns ``(padded_mesh, halo_size)`` for the epilog.
        """
        original_shape = displacement.shape
        pad_width, halo_start, _ = _halo_layout(halo_size, __aux_input)
        particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
        particle_mesh = jnp.pad(particle_mesh, pad_width)
        # One particle per cell of the unpadded local grid (so pmid and
        # displacement have the same length), shifted onto the padded mesh.
        # NOTE(review): pmid swaps its first two components (see _grid_pmid);
        # confirm this matches the padded axes for non-cubic local slabs.
        pmid = _grid_pmid(original_shape[:-1]) + jnp.array(halo_start)
        return scatter(pmid, displacement.reshape([-1, 3]),
                       particle_mesh), halo_size

    def multi_gpu_epilog(particle_mesh, halo_size, __aux_input=None):
        """Exchange halos between neighbours, then strip the padding."""
        _, halo_start, halo_extents = _halo_layout(halo_size, __aux_input)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return slice_unpad(particle_mesh, pad_width=halo_start)


class CICReadDXOperator(ShardedOperator):
    """Read mesh values at one particle per grid cell."""

    name = 'cic_read_dx'

    def single_gpu_impl(particle_mesh, halo_size=0):
        """Return mesh values gathered at every grid cell, shape (N,)."""
        del halo_size
        pmid = _grid_pmid(particle_mesh.shape)
        disp = jnp.zeros(pmid.shape, dtype='float32')
        # val=0: return the gathered values themselves, not 1 + values.
        return _gather(pmid, disp, particle_mesh, val=0)

    def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
        """Pad the local slab and fill its halo from neighbouring devices."""
        pad_width, _, halo_extents = _halo_layout(halo_size, __aux_input)
        particle_mesh = slice_pad(particle_mesh, pad_width=pad_width)
        particle_mesh = halo_exchange(particle_mesh,
                                      halo_extents=halo_extents,
                                      halo_periods=(True, True, True))
        return particle_mesh, halo_size

    def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
        """Gather at every interior cell of the halo-padded local mesh.

        Presumably ``particle_mesh`` was already padded by
        ``multi_gpu_prolog`` (as in CICReadOperator) -- confirm. Returns
        ``(values, halo_size)``.
        """
        _, halo_start, _ = _halo_layout(halo_size, __aux_input)
        interior_shape = tuple(
            n - 2 * h for n, h in zip(particle_mesh.shape, halo_start))
        pmid = _grid_pmid(interior_shape) + jnp.array(halo_start)
        disp = jnp.zeros(pmid.shape, dtype='float32')
        # TODO must be reshaped
        return _gather(pmid, disp, particle_mesh, val=0), halo_size


register_operator(CICPaintOperator)
register_operator(CICReadOperator)
register_operator(CICPaintDXOperator)
# NOTE(review): CICReadDXOperator is defined above but not registered.