Adding an example of jaxdecomp implementation

This commit is contained in:
EiffL 2022-11-26 17:27:14 +01:00
parent 6644b35d71
commit 6ca4c9191e
5 changed files with 166 additions and 192 deletions

View file

@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, comms=None):
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
if comms is not None:
if sharding_info is not None:
# Add some padding for the halo exchange
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
kernel.reshape([-1, 8]),
dnums)
if comms == None:
if sharding_info == None:
return mesh
else:
mesh = halo_reduce(mesh, halo_size, comms)
mesh = halo_reduce(mesh, sharding_info)
return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, comms=None):
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
if comms is not None:
if sharding_info is not None:
# Add some padding and perfom hao exchange to retrieve
# neighboring regions
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
mesh = halo_reduce(mesh, halo_size, comms)
# mesh = halo_reduce(mesh, sharding_info)
import jaxdecomp
mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)