From f391a7921182fe963cef77ab524e6d7a3afe44eb Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Mon, 9 Jun 2025 19:35:40 +0200 Subject: [PATCH] make normal_field signature consistent with jax.random.normal --- jaxpm/distributed.py | 6 +++--- jaxpm/pm.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 3b5cbfc..86cd816 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -166,11 +166,11 @@ def uniform_particles(mesh_shape, sharding=None): axis=-1) -def normal_field(mesh_shape, seed, sharding=None, dtype='float32'): +def normal_field(seed , shape, sharding=None, dtype='float32'): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is not None and not (gpu_mesh.empty): - local_mesh_shape = get_local_shape(mesh_shape, sharding) + local_mesh_shape = get_local_shape(shape, sharding) size = jax.device_count() # rank = jax.process_index() @@ -195,4 +195,4 @@ def normal_field(mesh_shape, seed, sharding=None, dtype='float32'): in_specs=P(None), out_specs=spec)(keys) # yapf: disable else: - return jax.random.normal(shape=mesh_shape, key=seed, dtype=dtype) + return jax.random.normal(shape=shape, key=seed, dtype=dtype) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 9951e1c..ae7db3f 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(mesh_shape, seed=seed, sharding=sharding) + field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2