From 49c93aacf62cccca7e428c3417147cbed70e846a Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sun, 8 Jun 2025 10:45:20 +0200 Subject: [PATCH] Format --- jaxpm/distributed.py | 2 +- jaxpm/painting.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index feb9d39..3b5cbfc 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(mesh_shape, seed, sharding=None , dtype='float32'): +def normal_field(mesh_shape, seed, sharding=None, dtype='float32'): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): diff --git a/jaxpm/painting.py b/jaxpm/painting.py index bd65784..13c5695 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -240,7 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size): def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None): halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) - halo_size = jax.tree.map(lambda x: x//2, halo_size) + # Halo size is halved for the read operation + # We only need to read the density field + # while in the painting operation we need to exchange and reduce the halo + # We chose to do that since it is much easier to write a custom jvp rule for exchange + # while it is a bit harder if there is a reduction involved + halo_size = jax.tree.map(lambda x: x // 2, halo_size) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents,