mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
|
@ -1,4 +1,5 @@
|
|||
import jax
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
|
||||
|
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
print(f" positions {positions.shape}")
|
||||
if sharding_info is not None:
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
|
||||
|
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
# Add some padding for the halo exchange
|
||||
with gpu_mesh:
|
||||
nbody_mesh = sharded_pad(nbody_mesh)
|
||||
|
||||
positions = add_halo(positions , halo_size)
|
||||
|
||||
with gpu_mesh:
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
floor = jit(jnp.floor)(positions)
|
||||
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
@jit
|
||||
def compute_kernels(positions , neighboor_coords):
|
||||
kernel = (1. - jnp.abs(positions - neighboor_coords))
|
||||
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
|
||||
|
||||
with gpu_mesh:
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
neighboor_coords = jit(jnp.add)(floor , connection)
|
||||
kernel = compute_kernels(positions , neighboor_coords)
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
|
||||
|
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
if sharding_info == None:
|
||||
return nbody_mesh
|
||||
else:
|
||||
with gpu_mesh :
|
||||
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
|
||||
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
|
||||
|
||||
return nbody_mesh
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue