painting.py no longer uses global mesh

This commit is contained in:
Wassim KABALAN 2024-10-22 12:10:01 -04:00
parent 80c56dced5
commit 105568e8db

View file

@ -11,7 +11,7 @@ from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter from jaxpm.painting_utils import gather, scatter
def cic_paint_impl(mesh, displacement, weight=None): def cic_paint_impl(grid_mesh, displacement, weight=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3] displacement field: [nx, ny, nz, 3]
@ -36,30 +36,34 @@ def cic_paint_impl(mesh, displacement, weight=None):
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'), neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape)) jnp.array(grid_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2), inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, scatter_dims_to_operand_dims=(0, 1,
2)) 2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]), mesh = lax.scatter_add(grid_mesh, neighboor_coords,
dnums) kernel.reshape([-1, 8]), dnums)
return mesh return mesh
@partial(jax.jit, static_argnums=(2, )) @partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(mesh, positions, halo_size=0, weight=None): def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size) halo_size, halo_extents = get_halo_size(halo_size, sharding)
mesh = slice_pad(mesh, halo_size) grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
mesh = autoshmap(cic_paint_impl,
in_specs=(P('x', 'y'), P('x', 'y'), P()), gpu_mesh = sharding.mesh if sharding is not None else None
out_specs=P('x', 'y'))(mesh, positions, weight) spec = sharding.spec if sharding is not None else P()
mesh = halo_exchange(mesh, grid_mesh = autoshmap(cic_paint_impl,
halo_extents=halo_extents, gpu_mesh=gpu_mesh,
halo_periods=(True, True)) in_specs=(spec, spec, P()),
mesh = slice_unpad(mesh, halo_size) out_specs=spec)(grid_mesh, positions, weight)
return mesh grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
return grid_mesh
def cic_read_impl(mesh, displacement): def cic_read_impl(mesh, displacement):
@ -92,27 +96,30 @@ def cic_read_impl(mesh, displacement):
displacement.shape[:-1]) displacement.shape[:-1])
@partial(jax.jit, static_argnums=(2, )) @partial(jax.jit, static_argnums=(2, 3))
def cic_read(mesh, displacement, halo_size=0): def cic_read(mesh, displacement, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = slice_pad(mesh, halo_size) mesh = slice_pad(mesh, halo_size, sharding=sharding)
mesh = halo_exchange(mesh, mesh = halo_exchange(mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True)) halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacement = autoshmap(cic_read_impl, displacement = autoshmap(cic_read_impl,
in_specs=(P('x', 'y'), P('x', 'y')), gpu_mesh=gpu_mesh,
out_specs=P('x', 'y'))(mesh, displacement) in_specs=(spec, spec),
out_specs=spec)(mesh, displacement)
return displacement return displacement
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh """ Paints positions onto a 2d mesh
mesh: [nx, ny] mesh: [nx, ny]
positions: [npart, 2] positions: [npart, 2]
weight: [npart] weight: [npart]
""" """
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -157,29 +164,32 @@ def cic_paint_dx_impl(displacements, halo_size):
return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh) return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh)
@partial(jax.jit, static_argnums=(1, )) @partial(jax.jit, static_argnums=(1, 2))
def cic_paint_dx(displacements, halo_size=0): def cic_paint_dx(displacements, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), gpu_mesh = sharding.mesh if sharding is not None else None
in_specs=(P('x', 'y')), spec = sharding.spec if sharding is not None else P()
out_specs=P('x', 'y'))(displacements) grid_mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=spec,
out_specs=spec)(displacements)
mesh = halo_exchange(mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True)) halo_periods=(True, True))
mesh = slice_unpad(mesh, halo_size) grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
return mesh return grid_mesh
def cic_read_dx_impl(mesh, halo_size): def cic_read_dx_impl(grid_mesh, halo_size):
halo_x, _ = halo_size[0] halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = [ original_shape = [
dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size) dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
] ]
a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
@ -190,32 +200,36 @@ def cic_read_dx_impl(mesh, halo_size):
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape) return gather(pmid, jnp.zeros_like(pmid),
grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(1, )) @partial(jax.jit, static_argnums=(1, 2))
def cic_read_dx(mesh, halo_size=0): def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
# return mesh # return mesh
halo_size, halo_extents = get_halo_size(halo_size) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = slice_pad(mesh, halo_size) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
mesh = halo_exchange(mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True)) halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
in_specs=(P('x', 'y')), gpu_mesh=gpu_mesh,
out_specs=P('x', 'y'))(mesh) in_specs=(spec),
out_specs=spec)(grid_mesh)
return displacements return displacements
def compensate_cic(field): def compensate_cic(field):
""" """
Compensate for CiC painting Compensate for CiC painting
Args: Args:
field: input 3D cic-painted field field: input 3D cic-painted field
Returns: Returns:
compensated_field compensated_field
""" """
nc = field.shape nc = field.shape
kvec = fftk(nc) kvec = fftk(nc)