update formatting

This commit is contained in:
EiffL 2024-07-09 18:02:57 -04:00
parent 6408aff1de
commit 319942a6bc
5 changed files with 113 additions and 96 deletions

View file

@ -1,26 +1,28 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxpm.kernels import cic_compensation, fftk
from jax.sharding import PartitionSpec as P
from functools import partial
from jaxpm.distributed import autoshmap
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
out_specs=P('x', 'y'))
from jaxpm.distributed import autoshmap
from jaxpm.kernels import cic_compensation, fftk
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y'))
def cic_paint(mesh, displacement, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
"""
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'),
axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@ -46,9 +48,7 @@ def cic_paint(mesh, displacement, weight=None):
return mesh
@partial(autoshmap,
in_specs=(P('x', 'y'), P('x','y')),
out_specs=P('x', 'y'))
@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
def cic_read(mesh, displacement):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
@ -56,11 +56,11 @@ def cic_read(mesh, displacement):
"""
# Compute the position of the particles on a regular grid
part_shape = displacement.shape
positions = jnp.stack(jnp.meshgrid(
jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'), axis=-1) + displacement
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
jnp.arange(part_shape[1]),
jnp.arange(part_shape[2]),
indexing='ij'),
axis=-1) + displacement
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@ -75,7 +75,8 @@ def cic_read(mesh, displacement):
jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
displacement.shape[:-1])
def cic_paint_2d(mesh, positions, weight):