make normal_field signature consistent with jax.random.normal

This commit is contained in:
Wassim Kabalan 2025-06-09 19:35:40 +02:00
parent 5807e1d3f4
commit f391a79211
2 changed files with 4 additions and 4 deletions

View file

@ -131,7 +131,7 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
Generate initial conditions.
"""
# Initialize a random field with one slice on each gpu
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field = normal_field(seed=seed , shape=mesh_shape, sharding=sharding)
field = fft3d(field)
kvec = fftk(field)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2