update by default normal_field dtype to match JAX

This commit is contained in:
Wassim Kabalan 2025-06-12 14:46:27 +02:00
parent f391a79211
commit e1a8134b8e

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(seed , shape, sharding=None, dtype='float32'):
def normal_field(seed , shape, sharding=None, dtype=float):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):