This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@ -1,4 +1,5 @@
import jax
from jax import jit
import jax.numpy as jnp
import jax.lax as lax
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
print(f" positions {positions.shape}")
if sharding_info is not None:
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
# Add some padding for the halo exchange
with gpu_mesh:
nbody_mesh = sharded_pad(nbody_mesh)
positions = add_halo(positions , halo_size)
with gpu_mesh:
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
floor = jit(jnp.floor)(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
@jit
def compute_kernels(positions , neighboor_coords):
kernel = (1. - jnp.abs(positions - neighboor_coords))
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
with gpu_mesh:
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jit(jnp.add)(floor , connection)
kernel = compute_kernels(positions , neighboor_coords)
neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
if sharding_info == None:
return nbody_mesh
else:
with gpu_mesh :
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
return nbody_mesh