mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 00:51:11 +00:00
Adding an example of jaxdecomp implementation
This commit is contained in:
parent
6644b35d71
commit
6ca4c9191e
5 changed files with 166 additions and 192 deletions
|
@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
|
|||
from jaxpm.kernels import fftk, cic_compensation
|
||||
|
||||
|
||||
def cic_paint(mesh, positions, halo_size=0, comms=None):
|
||||
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
if comms is not None:
|
||||
if sharding_info is not None:
|
||||
# Add some padding for the halo exchange
|
||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||
[halo_size, halo_size],
|
||||
|
@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
|
|||
kernel.reshape([-1, 8]),
|
||||
dnums)
|
||||
|
||||
if comms == None:
|
||||
if sharding_info == None:
|
||||
return mesh
|
||||
else:
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
mesh = halo_reduce(mesh, sharding_info)
|
||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
|
||||
def cic_read(mesh, positions, halo_size=0, comms=None):
|
||||
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
|
||||
if comms is not None:
|
||||
if sharding_info is not None:
|
||||
# Add some padding and perfom hao exchange to retrieve
|
||||
# neighboring regions
|
||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||
[halo_size, halo_size],
|
||||
[0, 0]])
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
# mesh = halo_reduce(mesh, sharding_info)
|
||||
import jaxdecomp
|
||||
mesh = jaxdecomp.halo_exchange(mesh,
|
||||
halo_extents=sharding_info.halo_extents,
|
||||
halo_periods=(True,True,True),
|
||||
pdims=sharding_info.pdims,
|
||||
global_shape=sharding_info.global_shape)
|
||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue