mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Fix seed for distributed normal
This commit is contained in:
parent
7501b5bc6d
commit
7f48cfa8af
2 changed files with 22 additions and 7 deletions
|
@ -131,3 +131,22 @@ def get_local_shape(mesh_shape):
|
|||
return [
|
||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
||||
]
|
||||
|
||||
|
||||
|
||||
def normal_field(mesh_shape, seed=None):
|
||||
"""Generate a Gaussian random field with the given power spectrum."""
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
local_mesh_shape = get_local_shape(mesh_shape)
|
||||
if seed is None:
|
||||
key = None
|
||||
else:
|
||||
size = jax.process_count()
|
||||
rank = jax.process_index()
|
||||
key = jax.random.split(seed, size)[rank]
|
||||
return autoshmap(
|
||||
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
|
||||
in_specs=P(None),
|
||||
out_specs=P('x', 'y'))(key) # yapf: disable
|
||||
else:
|
||||
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue