update cic read halo size and notebooks examples

This commit is contained in:
Wassim Kabalan 2025-06-07 19:26:37 +02:00
parent d4049e5db4
commit e7112e0c25
5 changed files with 161 additions and 176 deletions

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None):
def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):
@ -190,9 +190,9 @@ def normal_field(mesh_shape, seed, sharding=None):
return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)
return shard_map(
partial(normal, shape=local_mesh_shape, dtype='float32'),
partial(normal, shape=local_mesh_shape, dtype=dtype),
mesh=gpu_mesh,
in_specs=P(None),
out_specs=spec)(keys) # yapf: disable
else:
return jax.random.normal(shape=mesh_shape, key=seed)
return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype)