JaxPM/notebooks/05-MultiHost_PM.ipynb

Multi-Host Particle Mesh Simulation

In this notebook, we extend our Particle Mesh simulation across multiple nodes, enabling simulations at scales not achievable on a single machine. By leveraging distributed GPUs across hosts, we handle larger mesh shapes and box sizes efficiently.

Note: Since there's no direct way to run a multi-host notebook, I'll guide you step by step through submitting an interactive job and running the simulation from a script.

To run a multi-host simulation, you first need to allocate a job with salloc. This command requests resources on an HPC cluster.

Note: You can alternatively use sbatch with a SLURM script to submit the job. The exact salloc parameters may vary depending on your specific HPC cluster configuration.

In [ ]:
!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00  --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 &

Note: These salloc parameters are configured for the Jean Zay supercomputer in France. Adaptations might be necessary if using a different HPC cluster.

A few hours later

Use !squeue -u $USER -o "%i %D %b" to check the JOB ID and verify your resource allocation.

In this example, we've been allocated 32 GPUs split across 4 nodes.

In [19]:
!squeue -u $USER -o "%i %D %b"
JOBID NODES TRES_PER_NODE
467745 4 gres/gpu:8
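
If you prefer not to copy the job ID by hand, you can also query it from Python. This is only a convenience sketch; it assumes a SLURM version whose squeue supports --me (otherwise pass -u with your username) and that you currently have a single job in the queue:

import subprocess

# Ask SLURM for this user's job IDs only, one per line, without a header
result = subprocess.run(["squeue", "--me", "-h", "-o", "%i"],
                        capture_output=True, text=True, check=True)
jobid = result.stdout.split()[0]  # first (and here only) job in the queue
print(jobid)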

Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:

In [ ]:
import os

# Pop with a default instead of del, so a missing variable doesn't raise a KeyError
os.environ.pop('VSCODE_PROXY_URI', None)
os.environ.pop('NO_PROXY', None)
os.environ.pop('no_proxy', None)

Checking Available Compute Resources

Run the following command to initialize JAX distributed computing and display the devices available for this job:

In [21]:
!srun --jobid=467745 -n 32 python -c "import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None"
[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7), CudaDevice(id=8), CudaDevice(id=9), CudaDevice(id=10), CudaDevice(id=11), CudaDevice(id=12), CudaDevice(id=13), CudaDevice(id=14), CudaDevice(id=15), CudaDevice(id=16), CudaDevice(id=17), CudaDevice(id=18), CudaDevice(id=19), CudaDevice(id=20), CudaDevice(id=21), CudaDevice(id=22), CudaDevice(id=23), CudaDevice(id=24), CudaDevice(id=25), CudaDevice(id=26), CudaDevice(id=27), CudaDevice(id=28), CudaDevice(id=29), CudaDevice(id=30), CudaDevice(id=31)]

Running the Multi-Host Simulation Script

Run the simulation script across 32 processes:

!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2

The script, 05-MultiHost_PM.py, is located in the same directory as this notebook.

Multi-Host Simulation Script with Arguments

This script is nearly identical to the single-host version; the main addition is a call to jax.distributed.initialize() at the start, which enables multi-host parallelism. Here's a breakdown of the key arguments (a sketch of the corresponding setup code follows the list):

  • --pdims (-p): Specifies the processor grid dimensions as two integers, e.g. 16 2 for a 16 × 2 device mesh (the default is a 1 × N decomposition over all available devices).
  • --mesh_shape (-m): Defines the simulation mesh shape as three integers (default is [512, 512, 512]).
  • --box_size (-b): Sets the physical box size of the simulation as three floating-point values, e.g., 1000. 1000. 1000. (default is [500.0, 500.0, 500.0]).
  • --halo_size (-H): Specifies the halo size for boundary overlap across nodes (default is 64).
  • --solver (-s): Chooses the ODE solver (leapfrog or dopri8). The leapfrog solver uses a fixed step size, while dopri8 is an adaptive Runge-Kutta solver with a PID controller (default is leapfrog).
  • --snapshots (-st): Number of snapshots to save (warning: more snapshots increase memory usage).

The script saves the resulting fields to disk (the fields/ directory used below) so they can be loaded and visualized afterwards.
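
For orientation only, here is a minimal sketch of what the multi-host setup and argument parsing at the top of such a script could look like. It is not the actual contents of 05-MultiHost_PM.py (defaults and flag handling there may differ); it only mirrors the flags listed above.

import argparse

import jax

# Must run once per process, before any other JAX call, so that all
# processes launched by srun join the same distributed runtime.
jax.distributed.initialize()

parser = argparse.ArgumentParser(description="Multi-host PM simulation (sketch)")
parser.add_argument("-p", "--pdims", type=int, nargs=2,
                    default=[1, jax.device_count()],
                    help="processor grid, e.g. 16 2")
parser.add_argument("-m", "--mesh_shape", type=int, nargs=3, default=[512, 512, 512])
parser.add_argument("-b", "--box_size", type=float, nargs=3, default=[500.0, 500.0, 500.0])
parser.add_argument("-H", "--halo_size", type=int, default=64)
parser.add_argument("-s", "--solver", choices=["leapfrog", "dopri8"], default="leapfrog")
parser.add_argument("-st", "--snapshots", type=int, default=1)  # illustrative default
args = parser.parse_args()

# Only rank 0 prints, to avoid one copy of every log line per process
if jax.process_index() == 0:
    print(f"Running on {jax.device_count()} devices with pdims={args.pdims}")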

In [ ]:
import subprocess

# Define parameters as variables
jobid = "467745"
num_processes = 32
script_name = "05-MultiHost_PM.py"
mesh_shape = (1024, 1024, 1024)
box_size = (1000., 1000., 1000.)
halo_size = 128
solver = "leapfrog"
pdims = (16, 2)
snapshots = 2

# Build the command as a list, incorporating variables
command = [
    "srun",
    f"--jobid={jobid}",
    "-n", str(num_processes),
    "python", script_name,
    "--mesh_shape", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),
    "--box_size", str(box_size[0]), str(box_size[1]), str(box_size[2]),
    "--halo_size", str(halo_size),
    "-s", solver,
    "--pdims", str(pdims[0]), str(pdims[1]),
    "--snapshots", str(snapshots)
]

# Launch the job step; srun blocks until all 32 processes have finished
subprocess.run(command, check=True)

Loading and Visualizing Results

After running the multi-host simulation, we load the saved results from disk:

  • initial_conditions.npy: Initial conditions for the simulation.
  • lpt_displacements.npy: Displacement field from Lagrangian perturbation theory (LPT).
  • ode_solution_0.npy and ode_solution_1.npy: Solutions from the ODE solver at each snapshot.

We then use plot_fields_single_projection to visualize these fields and observe the results across multiple snapshots.

In [1]:
import numpy as np

initial_conditions = np.load('fields/initial_conditions.npy')
lpt_displacements = np.load('fields/lpt_displacements.npy')
ode_solution_0 = np.load('fields/ode_solution_0.npy')
ode_solution_1 = np.load('fields/ode_solution_1.npy')
In [2]:
from jaxpm.plotting import plot_fields_single_projection
fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": lpt_displacements,
    "ODE Solution 0": ode_solution_0,
    "ODE Solution 1": ode_solution_1
}
plot_fields_single_projection(fields)
[Output figure: single-projection views of the Initial Conditions, LPT Field, ODE Solution 0, and ODE Solution 1 fields.]