mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 09:37:11 +00:00
347 KiB
347 KiB
In [ ]:
# Installing JaxPM
# Lines starting with `!` are IPython shell escapes: they are run by the shell
# from inside the notebook, not by the Python interpreter. Upgrade JAX first,
# then install the jaxpm package used by the cells below.
!pip install --quiet --upgrade jax
!pip install --quiet jaxpm
In [2]:
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.ode import odeint
from jaxpm.painting import cic_paint
from jaxpm.pm import linear_field, lpt, make_ode_fn
In [6]:
# Simulation configuration shared by run_simulation and the plotting cell:
# a 128^3 mesh, a box of side 128 along each axis (units as expected by
# jaxpm -- presumably Mpc/h, confirm), and the scale factors at which particle
# positions are recorded. The first scale factor is also the ODE start time.
mesh_shape = [128] * 3
box_size = [128.0] * 3
snapshots = jnp.array([0.1, 0.5, 1.0])
@jax.jit
def run_simulation(omega_c, sigma8):
    """Run a JaxPM particle-mesh simulation for the given cosmology.

    Builds a Planck15 cosmology with the requested ``Omega_c`` and ``sigma8``,
    generates initial conditions with ``linear_field`` from its tabulated
    linear matter power spectrum (fixed PRNG seed 0), displaces one particle
    per mesh cell with ``lpt`` at scale factor ``a = 0.1``, and integrates the
    particles with ``odeint`` through every scale factor in ``snapshots``.

    Reads the module-level ``mesh_shape``, ``box_size`` and ``snapshots``.

    Args:
        omega_c: Cold dark matter density parameter ``Omega_c``.
        sigma8: Amplitude of matter fluctuations ``sigma8``.

    Returns:
        Tuple ``(initial_conditions, lpt_positions, snapshot_positions)``:

        - the initial field on the mesh, as returned by ``linear_field``;
        - the LPT-displaced particle positions, shape ``(N, 3)``;
        - the particle positions from ``odeint`` at each entry of
          ``snapshots``, one per snapshot along the first axis (presumably
          shape ``(len(snapshots), N, 3)`` -- confirm).
    """
    # A single cosmology object feeds both the power spectrum and the LPT
    # step, so both are guaranteed to use exactly the same parameters.
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Tabulate the linear matter power spectrum once, then interpolate it at
    # whatever (arbitrarily shaped) wavenumber array linear_field asks for.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)

    def pk_fn(x):
        return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)

    initial_conditions = linear_field(
        mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0)
    )

    # One particle per mesh cell. indexing='ij' keeps particles in the mesh's
    # C order (axis 0 slowest); the default 'xy' would swap the first two axes.
    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape], indexing="ij"),
        axis=-1,
    ).reshape([-1, 3])

    # LPT displacement and momentum at a = 0.1. This must equal snapshots[0],
    # because odeint treats the first entry of its time array as the time of
    # the initial state. The third LPT output is not needed here.
    dx, p, _ = lpt(cosmo, initial_conditions, particles, a=0.1)
    lpt_positions = particles + dx

    # Evolve [positions, momenta] forward through every requested snapshot.
    res = odeint(
        make_ode_fn(mesh_shape),
        [lpt_positions, p],
        snapshots,
        cosmo,
        rtol=1e-8,
        atol=1e-8,
    )

    # res[0] holds the positions at every snapshot; res[1] (momenta) is dropped.
    return initial_conditions, lpt_positions, res[0]
In [ ]:
# Run the fiducial simulation (Omega_c = 0.25, sigma8 = 0.8) and keep the
# initial field, the LPT particle positions and the per-snapshot positions.
initial_conditions, lpt_particles, ode_particles = run_simulation(
    omega_c=0.25, sigma8=0.8
)
In [20]:
from jaxpm.plotting import plot_fields_single_projection
def _log_density(positions):
    """CIC-paint particle positions onto an empty mesh; return log10(1 + mesh)."""
    return jnp.log10(cic_paint(jnp.zeros(mesh_shape), positions) + 1)


fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": _log_density(lpt_particles),
}
# The first snapshot is the ODE start state, i.e. the LPT positions already
# shown above, so it is skipped. Each remaining panel is titled with its
# scale factor. The old titles "field_{i}" did not identify the snapshot and
# were offset by one from the snapshot index.
for a, positions in zip(snapshots[1:], ode_particles[1:]):
    fields[f"a = {float(a):.2f}"] = _log_density(positions)
plot_fields_single_projection(fields)
In [ ]: