JaxPM/notebooks/02-MultiGPU_PM.ipynb
2024-10-26 22:49:17 +02:00

7.9 KiB

Open In Colab

In [8]:
!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git
!pip install diffrax
In [1]:
import os
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jax.experimental.ode import odeint
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint
from jaxpm.pm import linear_field, lpt, make_ode_fn
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib
In [ ]:
# The sharded simulation below lays devices out on a 2 x 4 mesh, so fail fast if fewer than 8 are visible.
assert 8 <= jax.device_count(), "This notebook requires a TPU or GPU runtime with 8 devices"
In [ ]:
from functools import partial

from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

# Collect a sharded array into a full array on the host; tiled=True concatenates
# the per-process pieces instead of stacking them along a new leading axis.
all_gather = partial(process_allgather, tiled=True)

# 2 x 4 logical device grid with axes named 'x' and 'y'.
pdims = (2, 4)
devices = create_device_mesh(pdims)
mesh = Mesh(devices, ('x', 'y'))

# Shard the first two array dimensions over the mesh axes ('x', 'y').
sharding = NamedSharding(mesh, P(*mesh.axis_names))
In [2]:
# Simulation configuration: a cubic 1024^3 mesh in a cubic box of side 1024
# (length units are whatever jaxpm expects — presumably Mpc/h, TODO confirm).
mesh_shape = [1024] * 3
box_size = [1024.0] * 3
halo_size = 128  # halo width handed to jaxpm's lpt / make_ode_fn below
snapshots = jnp.linspace(0.1, 1.0, 3)  # three save times evenly spaced from 0.1 to 1.0


@jax.jit
def run_simulation(omega_c, sigma8):
    """Run a sharded particle-mesh simulation for one set of cosmological parameters.

    Builds a Planck15 cosmology with the given ``Omega_c`` and ``sigma8``. It then:

    1. Generates the linear initial field on ``mesh_shape`` with a fixed seed
       (``PRNGKey(0)``).
    2. Applies jaxpm's LPT to a regular particle grid at the start time.
    3. Integrates the jaxpm equations of motion with diffrax's ``LeapfrogMidpoint``
       solver at a constant step, saving the state at every entry of ``snapshots``.

    All jaxpm calls receive the module-level ``sharding`` and ``halo_size``.

    Args:
        omega_c: Cold dark matter density, forwarded to ``jc.Planck15``.
        sigma8: Power spectrum amplitude, forwarded to ``jc.Planck15``.

    Returns:
        Tuple ``(initial_conditions, dx, ys, stats)``:

        - ``initial_conditions``: the linear field.
        - ``dx``: the first LPT output.
        - ``ys``: the solver states at ``snapshots``. Each state is
          ``jnp.stack([dx_component, p_component])`` along axis 0.
        - ``stats``: diffrax's solver statistics.
    """
    # A single cosmology object drives both the input power spectrum and the
    # dynamics, so the two cannot silently drift apart.
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
    # Start time shared by LPT and the ODE integration (presumably the scale factor a).
    a_init = 0.1

    # Linear matter power spectrum tabulated on a log-k grid, then interpolated.
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)

    # One particle per mesh cell, on the regular grid of integer cell coordinates.
    # NOTE(review): this (prod(mesh_shape), 3) array is built without an explicit
    # sharding; at 1024^3 it is ~12 GB of int32 — confirm XLA shards it as intended.
    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
                          axis=-1).reshape([-1, 3])

    # Initial displacement and momentum from LPT at the start time.
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   particles,
                   a_init,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward. diffrax calls f(t, y, args) while the jaxpm
    # ODE function takes (state, t, args), so adapt the order here. The
    # (d_dx, d_p) tuple is restacked to match the stacked y0 layout below.
    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
    term = ODETerm(
        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
    solver = LeapfrogMidpoint()

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(term,
                      solver,
                      t0=a_init,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)

    return initial_conditions, dx, res.ys, res.stats
In [ ]:
initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
print(f"Solver Stats : {solver_stats}")
In [ ]:
# Bring the sharded simulation outputs back to the host for painting and plotting.
# The previous cell binds `lpt_displacements` and `ode_solutions`. Gather them
# under the names the plotting cell expects.
initial_conditions = all_gather(initial_conditions)
lpt_particles = all_gather(lpt_displacements)
# ode_solutions has one saved state per snapshot along axis 0; gather each one.
ode_particles = [all_gather(state) for state in ode_solutions]
In [ ]:
from visualize import plot_fields

# Paint each particle set onto an empty mesh with cloud-in-cell and plot the fields.
# NOTE(review): the painted arrays are jaxpm's `dx` outputs. If jaxpm returns
# displacements rather than absolute positions here, the regular grid must be
# added before painting — confirm against the installed jaxpm version.
fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": cic_paint(jnp.zeros(mesh_shape), lpt_particles),
}
for i, state in enumerate(ode_particles):
    # Each saved ODE state is jnp.stack([dx, p]) (see y0 in run_simulation).
    # Paint only index 0; index 1 holds the momenta, which are not positions.
    fields[f"field_{i}"] = cic_paint(jnp.zeros(mesh_shape), state[0])
plot_fields(fields)
shape of grid_mesh: (256, 256, 256)