mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
import os

# Set ahead of the diffrax import further down in this file:
# avoid an allgather caused by diffrax.
os.environ["EQX_ON_ERROR"] = "nan"

import jax

# Start JAX's multi-process runtime. No arguments are given, so the
# coordinator settings are presumably auto-detected from the launcher
# environment -- TODO confirm for the target cluster.
jax.distributed.initialize()

# Index of this process and total number of participating processes.
rank, size = jax.process_index(), jax.process_count()
|
|
|
|
from functools import partial
|
|
|
|
import jax.numpy as jnp
|
|
import jax_cosmo as jc
|
|
import numpy as np
|
|
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
|
|
diffeqsolve)
|
|
from jax.experimental.mesh_utils import create_device_mesh
|
|
from jax.experimental.multihost_utils import process_allgather
|
|
from jax.sharding import Mesh, NamedSharding
|
|
from jax.sharding import PartitionSpec as P
|
|
|
|
from jaxpm.kernels import interpolate_power_spectrum
|
|
from jaxpm.painting import cic_paint_dx
|
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
|
|
|
# process_allgather with tiled=True pre-bound; used at the end of the script
# to collect the sharded results on every process.
all_gather = partial(process_allgather, tiled=True)

# 2 x 4 device grid (8 devices in total) with named axes 'x' and 'y'.
pdims = (2, 4)
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Simulation settings, read as globals by run_simulation.
mesh_shape = [512] * 3
box_size = [500., 500., 1000.]
halo_size = 64
# Two save times for the ODE solution: 0.1 and 1.0.
snapshots = jnp.linspace(0.1, 1., 2)
|
|
|
|
@jax.jit
def run_simulation(omega_c, sigma8):
    """Run one sharded particle-mesh simulation and return its main products.

    The pipeline is: tabulate the linear matter power spectrum of a Planck15
    cosmology (with ``Omega_c`` and ``sigma8`` overridden), draw a linear
    initial field from it, compute LPT displacements at ``a=0.1``, then
    integrate the ODE built by ``make_ode_fn`` from ``t0=0.1`` to ``t1=1.0``
    with a fixed step of ``0.01``.

    Grid size, box size, halo size, sharding and snapshot times are taken
    from the module-level ``mesh_shape``, ``box_size``, ``halo_size``,
    ``sharding`` and ``snapshots``. The random seed is fixed
    (``PRNGKey(0)``), so every call draws from the same key.

    Args:
        omega_c: Value passed as ``Omega_c`` to ``jc.Planck15``.
        sigma8: Value passed as ``sigma8`` to ``jc.Planck15``.

    Returns:
        A tuple ``(initial_conditions, dx, ys, stats)`` where
        ``initial_conditions`` is the output of ``linear_field``, ``dx`` is
        the first output of ``lpt``, ``ys`` is ``diffeqsolve``'s saved
        solution at ``snapshots`` (each state is ``[dx, p]`` stacked on
        axis 0), and ``stats`` is the solver statistics object.
    """
    # Tabulated linear matter power spectrum on 128 log-spaced wavenumbers.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)

    def pk_fn(x):
        # Interpolate the tabulated spectrum at the requested wavenumbers.
        return interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)

    # NOTE: the original built an explicit `particles` position grid here
    # that was never used; the solver state below is displacement-based
    # (dx, p), so the grid has been dropped.
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, _ = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
    # diffrax calls the vector field as f(t, y, args) while ode_fn takes
    # (state, t, args); its outputs are stacked to match the layout of y0.
    term = ODETerm(
        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
    solver = LeapfrogMidpoint()

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(term,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    return initial_conditions, dx, res.ys, res.stats
|
|
|
|
|
|
# Run the jitted simulation with omega_c=0.25 and sigma8=0.8.
(initial_conditions, lpt_displacements, ode_solutions,
 solver_stats) = run_simulation(0.25, 0.8)

# Every process reports, tagged with its own rank.
print(f"[{rank}] Simulation completed")
print(f"[{rank}] Solver stats: {solver_stats}")
|
|
|
|
# Gather the results

# The gathers run on every rank (they sit outside the rank check below);
# solver_stats is stored as returned, without gathering.
pm_dict = {
    "initial_conditions": all_gather(initial_conditions),
    "lpt_displacements": all_gather(lpt_displacements),
    "solver_stats": solver_stats,
}

# One entry per saved snapshot, gathered in snapshot order.
# NOTE(review): indexed access is kept deliberately -- iterating a sharded
# multi-host array directly may not be supported; confirm before changing.
for i in range(len(ode_solutions)):
    pm_dict[f"ode_solution_{i}"] = all_gather(ode_solutions[i])

# Only the first process writes the archive to disk.
if rank == 0:
    np.savez("multihost_pm.npz", **pm_dict)

print(f"[{rank}] Simulation results saved")
|