update cic read halo size and notebooks examples

This commit is contained in:
Wassim Kabalan 2025-06-07 19:26:37 +02:00
parent d4049e5db4
commit e7112e0c25
5 changed files with 161 additions and 176 deletions

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype)

View file

@ -167,7 +167,7 @@ def _cic_paint_dx_impl(displacements,
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@ -185,7 +185,7 @@ def _cic_paint_dx_impl(displacements,
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
chunk_size=2**24,
chunk_size=chunk_size,
val=weight)
@ -240,6 +240,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
halo_size = jax.tree.map(lambda x: x//2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long