diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 83d5cb9..20d49ca 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -131,3 +131,22 @@ def get_local_shape(mesh_shape): return [ mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2] ] + + + +def normal_field(mesh_shape, seed=None): + """Generate a Gaussian random field with the given power spectrum.""" + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape) + if seed is None: + key = None + else: + size = jax.process_count() + rank = jax.process_index() + key = jax.random.split(seed, size)[rank] + return autoshmap( + partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), + in_specs=P(None), + out_specs=P('x', 'y'))(key) # yapf: disable + else: + return jax.random.normal(shape=mesh_shape, key=seed) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 20c251e..7bff37d 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -5,7 +5,8 @@ import jax.numpy as jnp import jax_cosmo as jc from jax.sharding import PartitionSpec as P -from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d +from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, + normal_field) from jaxpm.growth import dGfa, growth_factor, growth_rate from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) @@ -71,12 +72,7 @@ def linear_field(mesh_shape, box_size, pk, seed): box_size[0] * box_size[1] * box_size[2]) # Initialize a random field with one slice on each gpu - local_mesh_shape = get_local_shape(mesh_shape) - field = autoshmap( - partial(jax.random.normal, shape=local_mesh_shape, dtype='float32'), - in_specs=P(None), - out_specs=P('x', 'y'))(seed) # yapf: disable - + field = normal_field(mesh_shape, seed=seed) field = fft3d(field) * pkmesh**0.5 field = ifft3d(field) return field