mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
Adding fixamp ics
This commit is contained in:
parent
101ef80d9c
commit
a67fc42495
1 changed files with 35 additions and 0 deletions
35
jaxpm/pm.py
35
jaxpm/pm.py
|
@ -51,6 +51,41 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
|||
field = jnp.fft.irfftn(field)
|
||||
return field
|
||||
|
||||
def box_muller_field(amplitude, phase, pkmesh):
|
||||
"""
|
||||
Obtain Gaussian random field given uniform random numbers and Pk amplitude.
|
||||
"""
|
||||
field = pkmesh**0.5 * jnp.sqrt(-jnp.log(amplitude)) * (jnp.cos(phase) + 1j * jnp.sin(phase))
|
||||
return jnp.fft.irfftn(field, (amplitude.shape[0],)*3, norm='ortho')
|
||||
|
||||
def linear_field_box_muller(mesh_shape, box_size, pk, seed, fixamp = False, inv_phase = False):
|
||||
"""
|
||||
Generate initial conditions with fixed amplitude and/or inverted phase.
|
||||
"""
|
||||
|
||||
key, subkey1, subkey2 = jax.random.split(seed, 3)
|
||||
kvec = fftk(mesh_shape)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
if fixamp:
|
||||
amplitude = jnp.ones_like(kmesh)
|
||||
else:
|
||||
amplitude = jax.random.uniform(subkey1, kmesh.shape, minval=1e-8)
|
||||
|
||||
|
||||
if inv_phase:
|
||||
phase = jax.random.uniform(subkey2, kmesh.shape, minval=1e-8) * 2 * jnp.pi
|
||||
ret = []
|
||||
ret.append(box_muller_field(amplitude, phase, pkmesh))
|
||||
phase = (phase + jnp.pi)
|
||||
ret.append(box_muller_field(amplitude, phase, pkmesh))
|
||||
return ret
|
||||
|
||||
return box_muller_field(amplitude, phase, pkmesh)
|
||||
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
|
|
Loading…
Add table
Reference in a new issue