mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
Adding fixamp ics
This commit is contained in:
parent
101ef80d9c
commit
a67fc42495
1 changed files with 35 additions and 0 deletions
35
jaxpm/pm.py
35
jaxpm/pm.py
|
@ -51,6 +51,41 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
||||||
field = jnp.fft.irfftn(field)
|
field = jnp.fft.irfftn(field)
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
def box_muller_field(amplitude, phase, pkmesh):
|
||||||
|
"""
|
||||||
|
Obtain Gaussian random field given uniform random numbers and Pk amplitude.
|
||||||
|
"""
|
||||||
|
field = pkmesh**0.5 * jnp.sqrt(-jnp.log(amplitude)) * (jnp.cos(phase) + 1j * jnp.sin(phase))
|
||||||
|
return jnp.fft.irfftn(field, (amplitude.shape[0],)*3, norm='ortho')
|
||||||
|
|
||||||
|
def linear_field_box_muller(mesh_shape, box_size, pk, seed, fixamp = False, inv_phase = False):
|
||||||
|
"""
|
||||||
|
Generate initial conditions with fixed amplitude and/or inverted phase.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key, subkey1, subkey2 = jax.random.split(seed, 3)
|
||||||
|
kvec = fftk(mesh_shape)
|
||||||
|
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
||||||
|
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
|
if fixamp:
|
||||||
|
amplitude = jnp.ones_like(kmesh)
|
||||||
|
else:
|
||||||
|
amplitude = jax.random.uniform(subkey1, kmesh.shape, minval=1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
if inv_phase:
|
||||||
|
phase = jax.random.uniform(subkey2, kmesh.shape, minval=1e-8) * 2 * jnp.pi
|
||||||
|
ret = []
|
||||||
|
ret.append(box_muller_field(amplitude, phase, pkmesh))
|
||||||
|
phase = (phase + jnp.pi)
|
||||||
|
ret.append(box_muller_field(amplitude, phase, pkmesh))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return box_muller_field(amplitude, phase, pkmesh)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def make_ode_fn(mesh_shape):
|
def make_ode_fn(mesh_shape):
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
|
|
Loading…
Add table
Reference in a new issue