From 079cebbdeae123e7ce349c7f5ed2943942865b57 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 25 Mar 2022 22:34:13 +0100 Subject: [PATCH] fix normalization of init cond --- jaxpm/pm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index f4b405e..3f39c9c 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -44,7 +44,7 @@ def linear_field(mesh_shape, box_size, pk, seed): """ kvec = fftk(mesh_shape) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 - pkmesh = pk(kmesh) + pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) field = jax.random.normal(seed, mesh_shape) field = jnp.fft.rfftn(field) * pkmesh**0.5