Multi-Host Particle Mesh Simulation¶
In this notebook, we extend our Particle Mesh simulation across multiple nodes, enabling simulations at scales not achievable on a single machine. By leveraging distributed GPUs across hosts, we handle larger mesh shapes and box sizes efficiently.
Note: Since there’s no direct way to run a multi-host notebook, I’ll guide you step by step on how to submit an interactive job from a script.
To run a multi-host simulation, you first need to allocate a job with salloc. This command requests resources on an HPC cluster.
Note: You can alternatively use sbatch with a SLURM script to submit the job. The exact salloc parameters may vary depending on your specific HPC cluster configuration.
!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 &
Note: These salloc parameters are configured for the Jean Zay supercomputer in France. Adaptations might be necessary if using a different HPC cluster.
A few hours later
Use !squeue -u $USER -o "%i %D %b" to check the JOB ID and verify your resource allocation.
In this example, we’ve been allocated 32 GPUs split across 4 nodes.
!squeue -u $USER -o "%i %D %b"
Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:
import os

# Remove proxy-related variables that can interfere with JAX's distributed setup.
# os.environ.pop(..., None) avoids a KeyError if a variable is not set.
for var in ('VSCODE_PROXY_URI', 'NO_PROXY', 'no_proxy'):
    os.environ.pop(var, None)
Checking Available Compute Resources¶
Run the following command to initialize JAX distributed computing and display the devices available for this job:
!srun --jobid=467745 -n 32 python -c "import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None"
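For readability, the same check can be written out as a small script instead of a one-liner. The file name check_devices.py below is hypothetical; the behaviour under srun is the same as the command above.

# check_devices.py -- hypothetical script name; equivalent to the one-liner above
import jax

# Initialize the distributed runtime; under srun this picks up the SLURM
# environment (number of processes, coordinator address) automatically
jax.distributed.initialize()

# Print from a single process to avoid 32 copies of the same output
if jax.process_index() == 0:
    print("Global devices:", jax.devices())
    print("Devices local to this host:", jax.local_device_count())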
Running the Multi-Host Simulation Script¶
Run the simulation script across 32 processes:
!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2
The script, located in the same directory as this notebook, is named 05-MultiHost_PM.py.
Multi-Host Simulation Script with Arguments¶
This script is nearly identical to the single-host version; the main addition is a call to jax.distributed.initialize() at the start, which enables multi-host parallelism. Here's a breakdown of the key arguments (a sketch of how they might be parsed follows the list):
--pdims (-p): Specifies the processor grid dimensions as two integers, e.g. 16 2 for a 16 x 2 device mesh (default is [1, jax.devices()]).
--mesh_shape (-m): Defines the simulation mesh shape as three integers (default is [512, 512, 512]).
--box_size (-b): Sets the physical box size of the simulation as three floating-point values, e.g. 1000. 1000. 1000. (default is [500.0, 500.0, 500.0]).
--halo_size (-H): Specifies the halo size for the boundary overlap across nodes (default is 64).
--solver (-s): Chooses the ODE solver, either leapfrog or dopri8. The leapfrog solver uses a fixed step size, while dopri8 is an adaptive Runge-Kutta solver with a PID controller (default is leapfrog).
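As a rough illustration, the argument handling described above could be set up with argparse along these lines. This is a sketch, not the actual contents of 05-MultiHost_PM.py, and the --pdims default is assumed here to use jax.device_count() rather than the device list itself.

# Hypothetical sketch of how 05-MultiHost_PM.py could parse these flags;
# the real script may differ in names and defaults
import argparse
import jax

parser = argparse.ArgumentParser(description="Multi-host Particle Mesh simulation")
parser.add_argument("-p", "--pdims", type=int, nargs=2, default=[1, jax.device_count()],
                    help="Processor grid dimensions, e.g. 16 2")
parser.add_argument("-m", "--mesh_shape", type=int, nargs=3, default=[512, 512, 512],
                    help="Simulation mesh shape")
parser.add_argument("-b", "--box_size", type=float, nargs=3, default=[500.0, 500.0, 500.0],
                    help="Physical box size of the simulation")
parser.add_argument("-H", "--halo_size", type=int, default=64,
                    help="Halo size for boundary overlap across nodes")
parser.add_argument("-s", "--solver", choices=["leapfrog", "dopri8"], default="leapfrog",
                    help="ODE solver")
args = parser.parse_args()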
The script also saves the results from across the nodes so they can be loaded after the job completes.
!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2
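The cross-node saving mentioned above might look roughly like the following sketch; process_allgather and the save_field helper are illustrative assumptions, not necessarily what the script actually does.

# Illustrative only: gather a sharded field from all processes and save it once
import jax
import numpy as np
from jax.experimental import multihost_utils

def save_field(name, sharded_array):
    # Gather the globally sharded array onto every process as a NumPy array
    gathered = multihost_utils.process_allgather(sharded_array, tiled=True)
    # Write from a single process so the file is not written 32 times
    if jax.process_index() == 0:
        np.save(f"{name}.npy", gathered)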
Loading and Visualizing Results¶
After running the multi-host simulation, we load the saved results from disk:
initial_conditions.npy: Initial conditions for the simulation.
lpt_displacements.npy: Linear perturbation displacements.
ode_solution_0.npy and ode_solution_1.npy: Solutions from the ODE solver at each snapshot.
We then use plot_fields_single_projection to visualize these fields and observe the results across multiple snapshots.
import numpy as np

# Load the fields saved to disk by the multi-host run
initial_conditions = np.load('initial_conditions.npy')
lpt_displacements = np.load('lpt_displacements.npy')
ode_solution_0 = np.load('ode_solution_0.npy')
ode_solution_1 = np.load('ode_solution_1.npy')
from visualize import plot_fields_single_projection

# Show a single 2D projection of each field, one panel per snapshot
fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": lpt_displacements,
    "ODE Solution 0": ode_solution_0,
    "ODE Solution 1": ode_solution_1
}
plot_fields_single_projection(fields)
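If the visualize module is not at hand, a minimal, hypothetical stand-in for plot_fields_single_projection could look like the sketch below: it sums each 3D field along one axis and shows the resulting projections side by side.

# Hypothetical stand-in for plot_fields_single_projection
import matplotlib.pyplot as plt
import numpy as np

def plot_fields_projection(fields, axis=0):
    # One panel per field, projected by summing along the chosen axis
    fig, axes = plt.subplots(1, len(fields), figsize=(4 * len(fields), 4))
    for ax, (name, field) in zip(np.atleast_1d(axes), fields.items()):
        ax.imshow(field.sum(axis=axis), cmap="magma")
        ax.set_title(name)
        ax.axis("off")
    plt.tight_layout()
    plt.show()

plot_fields_projection(fields)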