diff --git a/jaxpm/pm.py b/jaxpm/pm.py index f4b405e..3f39c9c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -44,7 +44,7 @@ def linear_field(mesh_shape, box_size, pk, seed): """ kvec = fftk(mesh_shape) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 - pkmesh = pk(kmesh) + pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) field = jax.random.normal(seed, mesh_shape) field = jnp.fft.rfftn(field) * pkmesh**0.5