implement distributed optimized cic_paint

This commit is contained in:
Wassim KABALAN 2024-07-18 12:39:15 +02:00
parent e62cd84cbd
commit 5775a37550
2 changed files with 286 additions and 7 deletions

View file

@ -5,14 +5,13 @@ import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange,
slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y'))
def cic_paint(mesh, displacement, weight=None):
def cic_paint_impl(mesh, displacement, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
@ -48,8 +47,22 @@ def cic_paint(mesh, displacement, weight=None):
return mesh
@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
def cic_read(mesh, displacement):
@partial(jax.jit, static_argnums=(2, ))
def cic_paint(mesh, positions, halo_size=0, weight=None):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
mesh = autoshmap(cic_paint_impl,
in_specs=(P('x', 'y'), P('x', 'y'), P()),
out_specs=P('x', 'y'))(mesh, positions, weight)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
mesh = slice_unpad(mesh, halo_size)
return mesh
def cic_read_impl(mesh, displacement):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement: [nx,ny,nz, 3]
@ -79,6 +92,21 @@ def cic_read(mesh, displacement):
displacement.shape[:-1])
@partial(jax.jit, static_argnums=(2, ))
def cic_read(mesh, displacement, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
displacement = autoshmap(cic_read_impl,
in_specs=(P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y'))(mesh, displacement)
return displacement
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
@ -108,6 +136,72 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def cic_paint_dx_impl(displacements, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
# Padding is forced to be zero in a single gpu run
a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
jnp.arange(particle_mesh.shape[1]),
jnp.arange(particle_mesh.shape[2]),
indexing='ij')
particle_mesh = jnp.pad(particle_mesh, halo_size)
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
@partial(jax.jit, static_argnums=(1, ))
def cic_paint_dx(displacements, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(displacements)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
mesh = slice_unpad(mesh, halo_size)
return mesh
def cic_read_dx_impl(mesh):
original_shape = mesh.shape
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij')
pmid = jnp.stack([a, b, c], axis=-1)
pmid = pmid.reshape([-1, 3])
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(1, ))
def cic_read_dx(mesh, halo_size=0):
halo_size, halo_extents = get_halo_size(halo_size)
mesh = slice_pad(mesh, halo_size)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True, True))
displacements = autoshmap(cic_read_dx_impl,
in_specs=(P('x', 'y')),
out_specs=P('x', 'y'))(mesh)
return displacements
def compensate_cic(field):
"""
Compensate for CiC painting