mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-13 01:11:11 +00:00
make normal_field signature consistent with jax.random.normal
This commit is contained in:
parent
5807e1d3f4
commit
f391a79211
2 changed files with 4 additions and 4 deletions
|
@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None):
|
|||
axis=-1)
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed, sharding=None, dtype='float32'):
|
||||
def normal_field(seed , shape, sharding=None, dtype='float32'):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||
local_mesh_shape = get_local_shape(shape, sharding)
|
||||
|
||||
size = jax.device_count()
|
||||
# rank = jax.process_index()
|
||||
|
@ -195,4 +195,4 @@ def normal_field(mesh_shape, seed, sharding=None, dtype='float32'):
|
|||
in_specs=P(None),
|
||||
out_specs=spec)(keys) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype)
|
||||
return jax.random.normal(shape=shape, key=seed, dtype=dtype)
|
||||
|
|
|
@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
|||
Generate initial conditions.
|
||||
"""
|
||||
# Initialize a random field with one slice on each gpu
|
||||
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
|
||||
field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding)
|
||||
field = fft3d(field)
|
||||
kvec = fftk(field)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue