update cic read halo size and notebooks examples

This commit is contained in:
Wassim Kabalan 2025-06-07 19:26:37 +02:00
parent d4049e5db4
commit e7112e0c25
5 changed files with 161 additions and 176 deletions

View file

@ -167,7 +167,7 @@ def _cic_paint_dx_impl(displacements,
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@ -185,7 +185,7 @@ def _cic_paint_dx_impl(displacements,
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=2**24,
chunk_size=chunk_size,
val=weight)
@ -240,6 +240,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
halo_size = jax.tree.map(lambda x: x//2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,