adding example of distributed solution

This commit is contained in:
EiffL 2024-07-09 17:45:28 -04:00
parent a2811c0606
commit a742065ffd
5 changed files with 192 additions and 62 deletions

View file

@ -3,13 +3,25 @@ import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jax.sharding import PartitionSpec as P
from functools import partial
from jaxpm.distributed import autoshmap
def cic_paint(mesh, positions, weight=None):
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
out_specs=P('x', 'y'))
def cic_paint(mesh, displacement, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
"""
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -34,11 +46,22 @@ def cic_paint(mesh, positions, weight=None):
return mesh
def cic_read(mesh, positions):
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y')),
out_specs=P('x', 'y'))
def cic_read(mesh, displacement):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
displacement: [nx,ny,nz, 3]
"""
# Compute the position of the particles on a regular grid
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -52,7 +75,7 @@ def cic_read(mesh, positions):
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
def cic_paint_2d(mesh, positions, weight):