mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
7.9 KiB
7.9 KiB
In [8]:
# Install JaxPM from its upstream repository, plus diffrax (imported below for the ODE solver).
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [1]:
import os
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.ode import odeint
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint
from jaxpm.pm import linear_field, lpt, make_ode_fn
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
In [ ]:
# The (2, 4) device mesh built in the next cell needs at least 8 devices.
# Raise explicitly rather than `assert`, so the check is not stripped under `python -O`.
if jax.device_count() < 8:
    raise RuntimeError(
        "This notebook requires a TPU or GPU runtime with 8 devices "
        f"(found {jax.device_count()})")
In [ ]:
from functools import partial

from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Device-grid shape: 2 x 4 = 8 devices, laid out along mesh axes ('x', 'y').
pdims = (2, 4)
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
# Partition the two leading array axes over the 'x' and 'y' mesh axes.
sharding = NamedSharding(mesh, P('x', 'y'))

# Host-side gather of sharded arrays; tiled=True concatenates instead of stacking.
all_gather = partial(process_allgather, tiled=True)
In [2]:
# Number of mesh cells along each of the three axes (cubic mesh).
mesh_shape = [1024] * 3
# Box extent along each axis; presumably Mpc/h (JaxPM convention) -- TODO confirm.
box_size = [1024.0] * 3
# Passed to `lpt` and `make_ode_fn`; presumably the halo-exchange width in cells -- confirm.
halo_size = 128
# Times at which the ODE state is saved: 0.1, 0.55 and 1.0.
snapshots = jnp.linspace(0.1, 1.0, 3)
@jax.jit
def run_simulation(omega_c, sigma8):
    """Run a distributed LPT + particle-mesh simulation for one cosmology.

    Builds a Planck15 cosmology with the given ``Omega_c`` and ``sigma8`` and
    draws a linear initial field with a fixed seed (``PRNGKey(0)``). It then
    places one particle per mesh cell, displaces them with ``lpt`` at time
    0.1 (presumably the scale factor a), and integrates the JaxPM ODE to 1.0
    with diffrax's ``LeapfrogMidpoint`` and a constant step of 0.01. The
    state is saved at the module-level ``snapshots``.

    Args:
        omega_c: Value passed as ``Omega_c`` to ``jc.Planck15``.
        sigma8: Value passed as ``sigma8`` to ``jc.Planck15``.

    Returns:
        Tuple ``(initial_conditions, dx, ys, stats)``:
        - ``initial_conditions``: the linear field.
        - ``dx``: the first output of ``lpt``.
        - ``ys``: the saved ODE states, one per snapshot, each ``stack([dx, p])``.
        - ``stats``: diffrax's solver statistics.
    """
    # Start/end times shared by the LPT call and the ODE integration.
    a_init, a_final = 0.1, 1.0

    # Build the cosmology once and reuse it for P(k), LPT and the ODE.
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Tabulate the linear matter power spectrum and wrap it as an interpolator.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)

    # One particle per mesh cell. indexing='ij' makes row n of `particles`
    # correspond to cell n of the C-order-flattened mesh; the default 'xy'
    # would swap the ordering of the first two coordinates.
    # NOTE(review): this (N, 3) array is built without `sharding`; at 1024^3 it
    # is ~12.9 GB of int32 -- confirm it is partitioned or consumed lazily downstream.
    particles = jnp.stack(
        jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape], indexing='ij'),
        axis=-1).reshape([-1, 3])

    # Initial displacement; the third output of `lpt` is unused here.
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a_init,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward. diffrax calls the vector field as
    # (t, y, args), while JaxPM's ODE takes (state, t, args) and returns a
    # sequence. Stacking keeps the diffrax state a single array.
    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
    term = ODETerm(
        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
    res = diffeqsolve(term,
                      LeapfrogMidpoint(),
                      t0=a_init,
                      t1=a_final,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=ConstantStepSize())
    return initial_conditions, dx, res.ys, res.stats
In [ ]:
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
print(f"Solver Stats : {solver_stats}")
In [ ]:
# Gather the sharded simulation outputs into full (unsharded) host arrays.
# The names produced by `run_simulation` above are `lpt_displacements` and
# `ode_solutions`; `lpt_particles` / `ode_particles` are defined here for plotting.
initial_conditions = all_gather(initial_conditions)
lpt_particles = all_gather(lpt_displacements)
# `ode_solutions` holds one saved state per snapshot, each stacked as [dx, p].
# Slice component 0 on-device so only that half is transferred, then gather once
# and split per snapshot.
ode_particles = list(all_gather(ode_solutions[:, 0]))
In [ ]:
from visualize import plot_fields

# Density fields to display, keyed by panel title.
fields = {"Initial Conditions": initial_conditions}
# CIC-paint the LPT particles onto an empty mesh.
fields["LPT Field"] = cic_paint(jnp.zeros(mesh_shape), lpt_particles)
# One painted density field per saved ODE snapshot.
for i, field in enumerate(ode_particles):
    fields[f"field_{i}"] = cic_paint(jnp.zeros(mesh_shape), field)
plot_fields(fields)