mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
update formatting
This commit is contained in:
parent
6408aff1de
commit
319942a6bc
5 changed files with 113 additions and 96 deletions
|
@ -1,26 +1,28 @@
|
|||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from functools import partial
|
||||
from jaxpm.distributed import autoshmap
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P('x','y'), P('x','y')),
|
||||
out_specs=P('x', 'y'))
|
||||
from jaxpm.distributed import autoshmap
|
||||
from jaxpm.kernels import cic_compensation, fftk
|
||||
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y'))
|
||||
def cic_paint(mesh, displacement, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
displacement field: [nx, ny, nz, 3]
|
||||
"""
|
||||
part_shape = displacement.shape
|
||||
positions = jnp.stack(jnp.meshgrid(
|
||||
jnp.arange(part_shape[0]),
|
||||
jnp.arange(part_shape[1]),
|
||||
jnp.arange(part_shape[2]),
|
||||
indexing='ij'), axis=-1) + displacement
|
||||
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
|
||||
jnp.arange(part_shape[1]),
|
||||
jnp.arange(part_shape[2]),
|
||||
indexing='ij'),
|
||||
axis=-1) + displacement
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
@ -46,9 +48,7 @@ def cic_paint(mesh, displacement, weight=None):
|
|||
return mesh
|
||||
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P('x','y')),
|
||||
out_specs=P('x', 'y'))
|
||||
@partial(autoshmap, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))
|
||||
def cic_read(mesh, displacement):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
|
@ -56,11 +56,11 @@ def cic_read(mesh, displacement):
|
|||
"""
|
||||
# Compute the position of the particles on a regular grid
|
||||
part_shape = displacement.shape
|
||||
positions = jnp.stack(jnp.meshgrid(
|
||||
jnp.arange(part_shape[0]),
|
||||
jnp.arange(part_shape[1]),
|
||||
jnp.arange(part_shape[2]),
|
||||
indexing='ij'), axis=-1) + displacement
|
||||
positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
|
||||
jnp.arange(part_shape[1]),
|
||||
jnp.arange(part_shape[2]),
|
||||
indexing='ij'),
|
||||
axis=-1) + displacement
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
@ -75,7 +75,8 @@ def cic_read(mesh, displacement):
|
|||
jnp.array(mesh.shape))
|
||||
|
||||
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(displacement.shape[:-1])
|
||||
neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
|
||||
displacement.shape[:-1])
|
||||
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue