mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
Fix seed for distributed normal
This commit is contained in:
parent
7501b5bc6d
commit
7f48cfa8af
2 changed files with 22 additions and 7 deletions
|
@ -131,3 +131,22 @@ def get_local_shape(mesh_shape):
|
||||||
return [
|
return [
|
||||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def normal_field(mesh_shape, seed=None):
|
||||||
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
|
local_mesh_shape = get_local_shape(mesh_shape)
|
||||||
|
if seed is None:
|
||||||
|
key = None
|
||||||
|
else:
|
||||||
|
size = jax.process_count()
|
||||||
|
rank = jax.process_index()
|
||||||
|
key = jax.random.split(seed, size)[rank]
|
||||||
|
return autoshmap(
|
||||||
|
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
|
||||||
|
in_specs=P(None),
|
||||||
|
out_specs=P('x', 'y'))(key) # yapf: disable
|
||||||
|
else:
|
||||||
|
return jax.random.normal(shape=mesh_shape, key=seed)
|
||||||
|
|
10
jaxpm/pm.py
10
jaxpm/pm.py
|
@ -5,7 +5,8 @@ import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d
|
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
|
||||||
|
normal_field)
|
||||||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||||
longrange_kernel)
|
longrange_kernel)
|
||||||
|
@ -71,12 +72,7 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
||||||
box_size[0] * box_size[1] * box_size[2])
|
box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
# Initialize a random field with one slice on each gpu
|
# Initialize a random field with one slice on each gpu
|
||||||
local_mesh_shape = get_local_shape(mesh_shape)
|
field = normal_field(mesh_shape, seed=seed)
|
||||||
field = autoshmap(
|
|
||||||
partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'),
|
|
||||||
in_specs=P(None),
|
|
||||||
out_specs=P('x', 'y'))(seed) # yapf: disable
|
|
||||||
|
|
||||||
field = fft3d(field) * pkmesh**0.5
|
field = fft3d(field) * pkmesh**0.5
|
||||||
field = ifft3d(field)
|
field = ifft3d(field)
|
||||||
return field
|
return field
|
||||||
|
|
Loading…
Add table
Reference in a new issue