update cic read halo size and notebooks examples

This commit is contained in:
Wassim Kabalan 2025-06-07 19:26:37 +02:00
parent d4049e5db4
commit e7112e0c25
5 changed files with 161 additions and 176 deletions

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1) axis=-1)
def normal_field(mesh_shape, seed, sharding=None): def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
"""Generate a Gaussian random field with the given power spectrum.""" """Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty): if gpu_mesh is not None and not (gpu_mesh.empty):
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype) return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map( return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'), partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh, mesh=gpu_mesh,
in_specs=P(None), in_specs=P(None),
out_specs=spec)(keys) # yapf: disable out_specs=spec)(keys) # yapf: disable
else: else:
return jax.random.normal(shape=mesh_shape, key=seed) return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype)

View file

@ -167,7 +167,7 @@ def _cic_paint_dx_impl(displacements,
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = displacements.shape original_shape = displacements.shape
particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') particle_mesh = jnp.zeros(original_shape[:-1], dtype=displacements.dtype)
if not jnp.isscalar(weight): if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]: if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape") raise ValueError("Weight shape must match particle shape")
@ -185,7 +185,7 @@ def _cic_paint_dx_impl(displacements,
return scatter(pmid.reshape([-1, 3]), return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]), displacements.reshape([-1, 3]),
particle_mesh, particle_mesh,
chunk_size=2**24, chunk_size=chunk_size,
val=weight) val=weight)
@ -240,6 +240,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None): def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
halo_size = jax.tree.map(lambda x: x//2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long