JaxPM/notebooks/04-MultiGPU_PM_Solvers.ipynb
Wassim KABALAN 695e359f56 jaxdecomp proto (#21)
Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
2024-12-20 05:44:02 -05:00


Multi-GPU Particle Mesh Simulation with Advanced Solvers


In [ ]:
import os
os.environ["EQX_ON_ERROR"] = "nan"
import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, diffeqsolve

Note: This notebook requires 8 devices (GPU or TPU).
If you're running on CPU or don't have access to 8 devices,
you can simulate multiple devices by adding the following code at the start BEFORE IMPORTING JAX:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

This is recommended only for debugging. If you use it, you will probably need to lower the mesh resolution.

In [2]:
assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"

Setting Up Device Mesh and Sharding for Multi-GPU Simulation

This cell configures a 2x4 device mesh across 8 devices and sets up named sharding to distribute data efficiently.

  • Device Mesh: pdims = (2, 4) arranges devices in a 2x4 grid.
  • Sharding with Mesh: jax.make_mesh(pdims, axis_names=('x', 'y')) builds the device mesh and names its two axes 'x' and 'y', which allows flexible mapping of array data across devices.
  • PartitionSpec and NamedSharding: PartitionSpec defines data partitioning across mesh axes ('x', 'y'), and NamedSharding(mesh, P('x', 'y')) specifies this sharding scheme for arrays in the simulation.

For more on sharding in general, see the JAX tutorial "Distributed arrays and automatic parallelization".
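To make the partitioning concrete, here is a minimal pure-Python sketch of the block each device would hold under P('x', 'y') on a 3D array. The helper `local_shard_shape` is hypothetical (it is not part of JAX or JaxPM) and assumes the array axes divide evenly by the mesh axes.

```python
def local_shard_shape(global_shape, pdims):
    """Per-device block shape when the first len(pdims) axes are split evenly.

    Mirrors P('x', 'y') on a 3D array: axis 0 is split across the 'x' mesh
    axis, axis 1 across 'y', and the last axis stays whole on every device.
    This is an illustrative helper, not a JAX or JaxPM function.
    """
    local = list(global_shape)
    for axis, n_dev in enumerate(pdims):
        if global_shape[axis] % n_dev != 0:
            raise ValueError(
                f"axis {axis} of size {global_shape[axis]} "
                f"is not divisible by {n_dev} devices")
        local[axis] = global_shape[axis] // n_dev
    return tuple(local)


pdims = (2, 4)
print(local_shard_shape((1024, 1024, 1024), pdims))  # (512, 256, 1024)
print(local_shard_shape((512, 512, 512), pdims))     # (256, 128, 512)
```

With pdims = (2, 4), each of the 8 devices therefore holds one eighth of the global array.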

In [ ]:
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

all_gather = partial(process_allgather, tiled=True)

pdims = (2, 4)
mesh = jax.make_mesh(pdims, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
In [ ]:
@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation_with_fields(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))
    solver = LeapfrogMidpoint()

    stepsize_controller = ConstantStepSize()
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]
    lpt_field = cic_paint_dx(dx , halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats

Large-Scale Simulation Across Multiple Devices

In this cell, we run a large simulation that would not be feasible on a single device. By distributing data across multiple devices, we achieve a higher resolution (mesh_shape = 1024 and box_size = 1000.) with effective boundary handling using a halo_size of 128.

We gather initial conditions and computed fields from all devices for visualization.
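As a rough guide to why this resolution needs several devices, the sketch below estimates the memory of a single scalar field. It assumes float32 values (4 bytes) and assumes the halo adds halo_size cells on both sides of each sharded axis; both are assumptions for illustration, not statements about JaxPM internals. The helper `field_gib` is hypothetical.

```python
def field_gib(shape, bytes_per_value=4):
    """Size in GiB of one array of the given shape (float32 by assumption)."""
    n_values = 1
    for s in shape:
        n_values *= s
    return n_values * bytes_per_value / 2**30


mesh_shape, pdims, halo_size = 1024, (2, 4), 128

global_shape = (mesh_shape,) * 3
# Block held by one device under P('x', 'y'): axes 0 and 1 are split.
local_shape = (mesh_shape // pdims[0], mesh_shape // pdims[1], mesh_shape)
# Assumed halo layout: halo_size extra cells on each side of the split axes.
padded_shape = (local_shape[0] + 2 * halo_size,
                local_shape[1] + 2 * halo_size,
                local_shape[2])

print(field_gib(global_shape))  # 4.0
print(field_gib(local_shape))   # 0.5
print(field_gib(padded_shape))  # 1.5
```

Under these assumptions a halo of 128 triples the size of the local block for this decomposition, so the halo size is a real part of the memory budget. The particle state (displacements and momenta) is several times larger than a single scalar field.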

In [ ]:
from jaxpm.plotting import plot_fields_single_projection

mesh_shape = 1024
box_size = 1000.
halo_size = 128
snapshots = (0.5 , 1.0)

initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)
ode_fields[-1].block_until_ready()
%timeit initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots);ode_fields[-1].block_until_ready()

initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

fields = {"Initial Conditions" : initial_conditions_g , "LPT Field" : lpt_field_g}
for i , field in enumerate(ode_fields_g):
    fields[f"field_{i}"] = field
plot_fields_single_projection(fields,project_axis=0)
45.6 s ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
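[Figure: output of plot_fields_single_projection for the fields "Initial Conditions", "LPT Field", "field_0", and "field_1", projected along axis 0]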

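This simulation runs in about 45 seconds. With a fixed step of 0.01 from a = 0.1 to a = 1.0, that is roughly 90 steps, or about half a second per step, and this figure also includes generating the initial conditions, the LPT step, and painting. That is impressive for a setup with over one billion particles ($1024^3 \approx 1.07 \times 10^9$). This performance demonstrates the efficiency of distributing data and computation across multiple devices.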
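The arithmetic behind the timing can be checked with a few lines of plain Python. This sketch assumes the constant-step solver takes exactly (t1 - t0) / dt0 steps and attributes the whole wall time to stepping, so the per-step figure is an upper bound; the helper `fixed_step_count` is hypothetical.

```python
def fixed_step_count(t0, t1, dt):
    """Number of constant-size steps needed to go from t0 to t1."""
    return round((t1 - t0) / dt)


wall_time_s = 45.6                    # from the %timeit output above
n_steps = fixed_step_count(0.1, 1.0, 0.01)
particles = 1024**3

seconds_per_step = wall_time_s / n_steps
particle_updates_per_s = particles * n_steps / wall_time_s

print(n_steps)                        # 90
print(particles)                      # 1073741824
print(f"{seconds_per_step:.2f} s per step")
print(f"{particle_updates_per_s:.2e} particle updates per second")
```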

Comparing ODE Solvers: Leapfrog vs. Dopri5

Next, we compare the Leapfrog solver with Dopri5 (an adaptive Runge-Kutta method) to observe differences in accuracy and performance for particle evolution.

In [11]:
mesh_shape = 512
box_size = 512.
halo_size = 64
snapshots = (0.5 , 1.0)

initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)
ode_fields[-1].block_until_ready()
%timeit initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_fields(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots);ode_fields[-1].block_until_ready()
5.04 s ± 9.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [ ]:
mesh_shape = 512
box_size = 512.
halo_size = 64
snapshots = (0.5, 1.0)

@partial(jax.jit, static_argnums=(2, 3, 4, 5))
def run_simulation_with_dopri(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):
    mesh_shape = (mesh_shape,) * 3
    box_size = (box_size,) * 3
    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape,
                                      box_size,
                                      pk_fn,
                                      seed=jax.random.PRNGKey(0),
                                      sharding=sharding)


    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Initial displacement
    dx, p, f = lpt(cosmo,
                   initial_conditions,
                   a=0.1,
                   order=2,
                   halo_size=halo_size,
                   sharding=sharding)

    # Evolve the simulation forward
    ode_fn = ODETerm(
        make_diffrax_ode(mesh_shape, paint_absolute_pos=False, sharding=sharding, halo_size=halo_size))
    solver = Dopri5()

    stepsize_controller = PIDController(rtol=1e-5,atol=1e-5)
    res = diffeqsolve(ode_fn,
                      solver,
                      t0=0.1,
                      t1=1.,
                      dt0=0.01,
                      y0=jnp.stack([dx, p], axis=0),
                      args=cosmo,
                      saveat=SaveAt(ts=snapshots),
                      stepsize_controller=stepsize_controller)
    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]
    lpt_field = cic_paint_dx(dx , halo_size=halo_size, sharding=sharding)
    return initial_conditions, lpt_field, ode_fields, res.stats

initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_dopri(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots)
ode_fields[-1].block_until_ready()
%timeit initial_conditions , lpt_field , ode_fields , solver_stats = run_simulation_with_dopri(0.25, 0.8 , mesh_shape, box_size , halo_size , snapshots);ode_fields[-1].block_until_ready()

print(f"Solver Stats : {solver_stats}")
4.44 s ± 8.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(12, dtype=int32, weak_type=True), 'num_rejected_steps': Array(0, dtype=int32, weak_type=True), 'num_steps': Array(12, dtype=int32, weak_type=True)}
In [10]:
initial_conditions_g = all_gather(initial_conditions)
lpt_field_g = all_gather(lpt_field)
ode_fields_g = [all_gather(p) for p in ode_fields]

fields = {"Initial Conditions" : initial_conditions_g , "LPT Field" : lpt_field_g}
for i , field in enumerate(ode_fields_g):
    fields[f"field_{i}"] = field
plot_fields_single_projection(fields,project_axis=0)
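[Figure: output of plot_fields_single_projection for the fields "Initial Conditions", "LPT Field", "field_0", and "field_1", projected along axis 0]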

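Switching solvers only requires changing the solver and step-size controller passed to diffeqsolve. Although Dopri5 offers adaptive stepping and needed only 12 accepted steps, it didn't yield a significant performance boost over Leapfrog in this case: 4.44 s versus 5.04 s on the $512^3$ mesh.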
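One way to read the timings and solver statistics above is to count force evaluations rather than steps. The sketch below assumes one vector-field evaluation per Leapfrog step and six new evaluations per Dopri5 step (seven stages with first-same-as-last reuse, plus one initial evaluation); these per-step costs are assumptions about the methods, not measurements from this run. The helper `vector_field_evals` is hypothetical.

```python
def vector_field_evals(n_steps, evals_per_step, extra=0):
    """Total vector-field evaluations for a run (illustrative cost model)."""
    return n_steps * evals_per_step + extra


# Leapfrog: fixed dt0 = 0.01 from a = 0.1 to a = 1.0, one evaluation per step.
leapfrog_steps = round((1.0 - 0.1) / 0.01)
leapfrog_evals = vector_field_evals(leapfrog_steps, 1)

# Dopri5: 12 accepted steps and 0 rejected steps, per the solver stats.
dopri5_evals = vector_field_evals(12, 6, extra=1)

print(leapfrog_evals, dopri5_evals)                       # 90 73
print(f"eval ratio: {dopri5_evals / leapfrog_evals:.2f}")  # eval ratio: 0.81
print(f"time ratio: {4.44 / 5.04:.2f}")                    # time ratio: 0.88
```

Under this cost model the two runs perform a similar number of force evaluations, which is consistent with the similar wall times even though Dopri5 takes far fewer steps.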

Note: Dopri5 uses a PIDController for adaptive stepping, which might face challenges in distributed setups. In my experience, it works well without triggering all-gathers, but make sure to set:

os.environ["EQX_ON_ERROR"] = "nan"

before importing diffrax, so that runtime errors produce NaNs instead of raising.

However, Dopri5 requires more memory than Leapfrog, which makes a $1024^3$ mesh simulation infeasible on eight A100 GPUs with 80 GB of memory each. For larger setups, we'll need more compute resources; this is covered in the final notebook, 05-MultiHost_PM.ipynb.
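A back-of-the-envelope sketch of the solver state size gives a feel for the difference. It assumes float32 values, a state of shape (2, N, N, N, 3) as built by jnp.stack([dx, p], axis=0), seven state-sized stage buffers for Dopri5, and two state-sized buffers for Leapfrog. It ignores halo padding, FFT workspaces, and compiler temporaries, so it is a lower bound and does not reproduce the actual footprint. The helper `state_gib` is hypothetical.

```python
def state_gib(mesh_size, bytes_per_value=4):
    """GiB for one ODE state: displacements and momenta, 3 components each."""
    n_values = 2 * 3 * mesh_size**3
    return n_values * bytes_per_value / 2**30


n_devices = 8
state = state_gib(1024)                 # whole state, all devices
per_device = state / n_devices          # evenly sharded share

# Assumed buffer counts (illustrative only):
leapfrog_buffers = 2 * per_device       # current and previous state
dopri5_stage_buffers = 7 * per_device   # one buffer per Runge-Kutta stage

print(state)                  # 24.0
print(per_device)             # 3.0
print(leapfrog_buffers)       # 6.0
print(dopri5_stage_buffers)   # 21.0
```

Even this lower bound puts the Dopri5 stage storage at a sizeable fraction of an 80 GB device before any padding or workspace is counted.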