mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
implement distributed optimized cic_paint
This commit is contained in:
parent
e62cd84cbd
commit
5775a37550
2 changed files with 286 additions and 7 deletions
|
@ -5,14 +5,13 @@ import jax.lax as lax
|
|||
import jax.numpy as jnp
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
|
||||
slice_pad, slice_unpad)
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jaxpm.painting_utils import gather, scatter
|
||||
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y'))
|
||||
def cic_paint(mesh, displacement, weight=None):
|
||||
def cic_paint_impl(mesh, displacement, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
|
@ -48,8 +47,22 @@ def cic_paint(mesh, displacement, weight=None):
|
|||
return mesh
|
||||
|
||||
|
||||
@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
|
||||
def cic_read(mesh, displacement):
|
||||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_paint(mesh, positions, halo_size=0, weight=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
mesh = autoshmap(cic_paint_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))(mesh, positions, weight)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
return mesh
|
||||
|
||||
|
||||
def cic_read_impl(mesh, displacement):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement: [nx,ny,nz, 3]
|
||||
|
@ -79,6 +92,21 @@ def cic_read(mesh, displacement):
|
|||
displacement.shape[:-1])
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_read(mesh, displacement, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh, displacement)
|
||||
|
||||
return displacement
|
||||
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
""" Paints positions onto a 2d mesh
|
||||
mesh: [nx, ny]
|
||||
|
@ -108,6 +136,72 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def cic_paint_dx_impl(displacements, halo_size):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
|
||||
|
||||
# Padding is forced to be zero in a single gpu run
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
|
||||
jnp.arange(particle_mesh.shape[1]),
|
||||
jnp.arange(particle_mesh.shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
particle_mesh = jnp.pad(particle_mesh, halo_size)
|
||||
|
||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_paint_dx(displacements, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(displacements)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
return mesh
|
||||
|
||||
|
||||
def cic_read_dx_impl(mesh):
|
||||
|
||||
original_shape = mesh.shape
|
||||
|
||||
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij')
|
||||
|
||||
pmid = jnp.stack([a, b, c], axis=-1)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
|
||||
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_read_dx(mesh, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
displacements = autoshmap(cic_read_dx_impl,
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh)
|
||||
return displacements
|
||||
|
||||
|
||||
def compensate_cic(field):
|
||||
"""
|
||||
Compensate for CiC painting
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue