This commit is contained in:
Wassim Kabalan 2025-06-08 10:45:20 +02:00
parent 41ae41ace3
commit 49c93aacf6
2 changed files with 7 additions and 2 deletions

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
def normal_field(mesh_shape, seed, sharding=None, dtype='float32'):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):

View file

@ -240,7 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
halo_size = jax.tree.map(lambda x: x//2, halo_size)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,