diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 86cd816..2d361f6 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(seed , shape, sharding=None, dtype='float32'): +def normal_field(seed , shape, sharding=None, dtype=float): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty):