fix normalization of init cond

This commit is contained in:
EiffL 2022-03-25 22:34:13 +01:00
parent 607acf2e4f
commit 079cebbdea

View file

@ -44,7 +44,7 @@ def linear_field(mesh_shape, box_size, pk, seed):
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh)
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5