mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-15 10:21:11 +00:00
update by default normal_field dtype to match JAX
This commit is contained in:
parent
f391a79211
commit
e1a8134b8e
1 changed files with 1 additions and 1 deletions
|
@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
||||||
axis=-1)
|
axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def normal_field(seed , shape, sharding=None, dtype='float32'):
|
def normal_field(seed , shape, sharding=None, dtype=float):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue