mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 19:50:55 +00:00
Adjust painting operators
This commit is contained in:
parent
6d8f130be7
commit
d2fb1ee1e2
2 changed files with 96 additions and 59 deletions
|
@ -4,6 +4,8 @@ import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax import lax
|
from jax import lax
|
||||||
from jax.lax import scan
|
from jax.lax import scan
|
||||||
|
from jax.sharding import NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
from jaxdecomp import halo_exchange
|
from jaxdecomp import halo_exchange
|
||||||
|
|
||||||
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
from jaxpm._src.spmd_config import (CallBackOperator, CustomPartionedOperator,
|
||||||
|
@ -128,6 +130,13 @@ class CICPaintOperator(ShardedOperator):
|
||||||
|
|
||||||
return [get_axis_size(base_sharding, i) for i in range(2)]
|
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_specs = base_sharding.spec, base_sharding.spec, P()
|
||||||
|
out_specs = base_sharding.spec
|
||||||
|
|
||||||
|
return in_specs, out_specs
|
||||||
|
|
||||||
|
|
||||||
class CICReadOperator(ShardedOperator):
|
class CICReadOperator(ShardedOperator):
|
||||||
|
|
||||||
|
@ -215,6 +224,24 @@ class CICReadOperator(ShardedOperator):
|
||||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
|
||||||
return particles.reshape(original_shape)
|
return particles.reshape(original_shape)
|
||||||
|
|
||||||
|
def get_aux_input_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
def get_axis_size(sharding, index):
|
||||||
|
axis_name = sharding.spec[index]
|
||||||
|
if axis_name == None:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sharding.mesh.shape[sharding.spec[index]]
|
||||||
|
|
||||||
|
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_specs = base_sharding.spec, base_sharding.spec, P()
|
||||||
|
out_specs = base_sharding.spec
|
||||||
|
|
||||||
|
return in_specs, out_specs
|
||||||
|
|
||||||
|
|
||||||
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
"""Split and reshape particle arrays into chunks and remainders, with the remainders
|
||||||
|
@ -305,7 +332,7 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
||||||
|
|
||||||
|
|
||||||
def _scatter_chunk(carry, chunk):
|
def _scatter_chunk(carry, chunk):
|
||||||
mesh_shape , mesh, offset, cell_size = carry
|
mesh, offset, cell_size, mesh_shape = carry
|
||||||
pmid, disp, val = chunk
|
pmid, disp, val = chunk
|
||||||
spatial_ndim = pmid.shape[1]
|
spatial_ndim = pmid.shape[1]
|
||||||
spatial_shape = mesh.shape
|
spatial_shape = mesh.shape
|
||||||
|
@ -317,7 +344,7 @@ def _scatter_chunk(carry, chunk):
|
||||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
mesh = mesh.at[ind].add(val * frac)
|
mesh = mesh.at[ind].add(val * frac)
|
||||||
|
|
||||||
carry = mesh, offset, cell_size
|
carry = mesh, offset, cell_size, mesh_shape
|
||||||
return carry, None
|
return carry, None
|
||||||
|
|
||||||
|
|
||||||
|
@ -334,7 +361,7 @@ def scatter(pmid,
|
||||||
mesh = jnp.asarray(mesh)
|
mesh = jnp.asarray(mesh)
|
||||||
|
|
||||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
carry = mesh.shape , mesh, offset, cell_size
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
if remainder is not None:
|
if remainder is not None:
|
||||||
carry = _scatter_chunk(carry, remainder)[0]
|
carry = _scatter_chunk(carry, remainder)[0]
|
||||||
carry = scan(_scatter_chunk, carry, chunks)[0]
|
carry = scan(_scatter_chunk, carry, chunks)[0]
|
||||||
|
@ -342,31 +369,6 @@ def scatter(pmid,
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
def gather(ptcl, conf, mesh, val=1, offset=0, cell_size=None):
|
|
||||||
"""Gather particle values from mesh multilinearly in n-D.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
ptcl : Particles
|
|
||||||
conf : Configuration
|
|
||||||
mesh : ArrayLike
|
|
||||||
Input mesh.
|
|
||||||
val : ArrayLike, optional
|
|
||||||
Input values, can be 0D.
|
|
||||||
offset : ArrayLike, optional
|
|
||||||
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
|
|
||||||
cell_size : float, optional
|
|
||||||
Mesh cell size in [L]. Default is ``conf.cell_size``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
val : jax.Array
|
|
||||||
Output values.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return _gather(ptcl.pmid, ptcl.disp, conf, mesh, val, offset, cell_size)
|
|
||||||
|
|
||||||
|
|
||||||
def _chunk_cat(remainder_array, chunked_array):
|
def _chunk_cat(remainder_array, chunked_array):
|
||||||
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
"""Reshape and concatenate one remainder and one chunked particle arrays."""
|
||||||
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
|
||||||
|
@ -377,7 +379,7 @@ def _chunk_cat(remainder_array, chunked_array):
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=None):
|
def _gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
mesh = jnp.asarray(mesh)
|
mesh = jnp.asarray(mesh)
|
||||||
|
@ -388,10 +390,9 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non
|
||||||
raise ValueError('channel shape mismatch: '
|
raise ValueError('channel shape mismatch: '
|
||||||
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')
|
||||||
|
|
||||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp,
|
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||||
val)
|
|
||||||
|
|
||||||
carry = mesh.shape , mesh, offset, cell_size
|
carry = mesh, offset, cell_size, mesh.shape
|
||||||
val_0 = None
|
val_0 = None
|
||||||
if remainder is not None:
|
if remainder is not None:
|
||||||
val_0 = _gather_chunk(carry, remainder)[1]
|
val_0 = _gather_chunk(carry, remainder)[1]
|
||||||
|
@ -403,7 +404,7 @@ def _gather(pmid, disp, mesh , chunk_size=2**24, val=1, offset=0, cell_size=Non
|
||||||
|
|
||||||
|
|
||||||
def _gather_chunk(carry, chunk):
|
def _gather_chunk(carry, chunk):
|
||||||
mesh_shape , mesh, offset, cell_size = carry
|
mesh, offset, cell_size, mesh_shape = carry
|
||||||
pmid, disp, val = chunk
|
pmid, disp, val = chunk
|
||||||
|
|
||||||
spatial_ndim = pmid.shape[1]
|
spatial_ndim = pmid.shape[1]
|
||||||
|
@ -413,8 +414,8 @@ def _gather_chunk(carry, chunk):
|
||||||
chan_axis = tuple(range(-chan_ndim, 0))
|
chan_axis = tuple(range(-chan_ndim, 0))
|
||||||
|
|
||||||
# multilinear mesh indices and fractions
|
# multilinear mesh indices and fractions
|
||||||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset,
|
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||||
cell_size, spatial_shape, False)
|
spatial_shape)
|
||||||
|
|
||||||
# gather
|
# gather
|
||||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||||
|
@ -441,7 +442,7 @@ class CICPaintDXOperator(ShardedOperator):
|
||||||
jnp.arange(particle_mesh.shape[2]),
|
jnp.arange(particle_mesh.shape[2]),
|
||||||
indexing='ij')
|
indexing='ij')
|
||||||
|
|
||||||
pmid = jnp.stack([b, a, c], axis=-1)
|
pmid = jnp.stack([a, b, c], axis=-1)
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
|
return scatter(pmid, displacement.reshape([-1, 3]), particle_mesh)
|
||||||
|
|
||||||
|
@ -489,6 +490,24 @@ class CICPaintDXOperator(ShardedOperator):
|
||||||
|
|
||||||
return particle_mesh
|
return particle_mesh
|
||||||
|
|
||||||
|
def get_aux_input_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
def get_axis_size(sharding, index):
|
||||||
|
axis_name = sharding.spec[index]
|
||||||
|
if axis_name == None:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sharding.mesh.shape[sharding.spec[index]]
|
||||||
|
|
||||||
|
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_specs = base_sharding.spec, P()
|
||||||
|
out_specs = base_sharding.spec
|
||||||
|
|
||||||
|
return in_specs, out_specs
|
||||||
|
|
||||||
|
|
||||||
class CICReadDXOperator(ShardedOperator):
|
class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
|
@ -498,7 +517,7 @@ class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
del halo_size
|
del halo_size
|
||||||
|
|
||||||
original_shape = (*particle_mesh.shape, 3)
|
original_shape = particle_mesh.shape
|
||||||
|
|
||||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
jnp.arange(particle_mesh.shape[1]),
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
@ -507,7 +526,9 @@ class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
pmid = jnp.stack([b, a, c], axis=-1)
|
pmid = jnp.stack([b, a, c], axis=-1)
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
|
positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
|
||||||
|
|
||||||
|
return positions.reshape(original_shape)
|
||||||
|
|
||||||
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
|
def multi_gpu_prolog(particle_mesh, halo_size=0, __aux_input=None):
|
||||||
|
|
||||||
|
@ -529,19 +550,9 @@ class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
return particle_mesh, halo_size
|
return particle_mesh, halo_size
|
||||||
|
|
||||||
|
def multi_gpu_impl(particle_mesh, halo_size, __aux_input=None):
|
||||||
|
|
||||||
def multi_gpu_impl(particle_mesh, halo_size=0, __aux_input=None):
|
original_shape = particle_mesh.shape
|
||||||
|
|
||||||
original_shape = (*particle_mesh.shape, 3)
|
|
||||||
halo_tuple = (halo_size, halo_size)
|
|
||||||
if __aux_input[0] == 1:
|
|
||||||
halo_width = ((0, 0), halo_tuple, (0, 0))
|
|
||||||
elif __aux_input[1] == 1:
|
|
||||||
halo_width = (halo_tuple, (0, 0), (0, 0))
|
|
||||||
else:
|
|
||||||
halo_width = (halo_tuple, halo_tuple, (0, 0))
|
|
||||||
|
|
||||||
particle_mesh = jnp.pad(particle_mesh, halo_width)
|
|
||||||
|
|
||||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||||
jnp.arange(particle_mesh.shape[1]),
|
jnp.arange(particle_mesh.shape[1]),
|
||||||
|
@ -550,10 +561,30 @@ class CICReadDXOperator(ShardedOperator):
|
||||||
|
|
||||||
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
pmid = jnp.stack([b + halo_size, a + halo_size, c], axis=-1)
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
# TODO must be reshaped
|
positions = _gather(pmid, jnp.zeros_like(pmid), particle_mesh)
|
||||||
return _gather(pmid, jnp.zeros_like(pmid), particle_mesh), halo_size
|
|
||||||
|
return positions.reshape(original_shape)
|
||||||
|
|
||||||
|
def get_aux_input_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
def get_axis_size(sharding, index):
|
||||||
|
axis_name = sharding.spec[index]
|
||||||
|
if axis_name == None:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return sharding.mesh.shape[sharding.spec[index]]
|
||||||
|
|
||||||
|
return [get_axis_size(base_sharding, i) for i in range(2)]
|
||||||
|
|
||||||
|
def infer_sharding_from_base_sharding(base_sharding):
|
||||||
|
|
||||||
|
in_specs = base_sharding.spec, P()
|
||||||
|
out_specs = base_sharding.spec
|
||||||
|
|
||||||
|
return in_specs, out_specs
|
||||||
|
|
||||||
|
|
||||||
register_operator(CICPaintOperator)
|
register_operator(CICPaintOperator)
|
||||||
register_operator(CICReadOperator)
|
register_operator(CICReadOperator)
|
||||||
register_operator(CICPaintDXOperator)
|
register_operator(CICPaintDXOperator)
|
||||||
|
register_operator(CICReadDXOperator)
|
||||||
|
|
|
@ -4,22 +4,28 @@ import jax.numpy as jnp
|
||||||
|
|
||||||
import jaxpm
|
import jaxpm
|
||||||
import jaxpm.ops
|
import jaxpm.ops
|
||||||
from jaxpm.kernels import cic_compensation, fftk
|
from jaxpm._src.spmd_config import pm_operators
|
||||||
|
from jaxpm.kernels import cic_compensation
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(particle_mesh, positions, halo_size=0):
|
def cic_paint(particle_mesh, positions, halo_size=0):
|
||||||
return jaxpm.ops.cic_paint(particle_mesh, positions, halo_size=halo_size)
|
return pm_operators.cic_paint(particle_mesh,
|
||||||
|
positions,
|
||||||
|
halo_size=halo_size)
|
||||||
|
|
||||||
|
|
||||||
def cic_read(mesh, positions, halo_size=0):
|
def cic_read(mesh, positions, halo_size=0):
|
||||||
return jaxpm.ops.cic_read(mesh, positions, halo_size=halo_size)
|
return pm_operators.cic_read(mesh, positions, halo_size=halo_size)
|
||||||
|
|
||||||
|
|
||||||
def cic_paint_dx(displacements, halo_size=0):
|
def cic_paint_dx(displacements, halo_size=0):
|
||||||
return jaxpm.ops.cic_paint_dx(displacements, halo_size=halo_size)
|
return pm_operators.cic_paint_dx(displacements, halo_size=halo_size)
|
||||||
|
|
||||||
|
|
||||||
|
def cic_read_dx(particle_mesh, halo_size=0):
|
||||||
|
return pm_operators.cic_read_dx(particle_mesh, halo_size=halo_size)
|
||||||
|
|
||||||
|
|
||||||
# TO REDO
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
""" Paints positions onto a 2d mesh
|
""" Paints positions onto a 2d mesh
|
||||||
mesh: [nx, ny]
|
mesh: [nx, ny]
|
||||||
|
|
Loading…
Add table
Reference in a new issue