* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
1.2 MiB
Multi-Host Particle Mesh Simulation¶
In this notebook, we extend our Particle Mesh simulation across multiple nodes, enabling simulations at scales not achievable on a single machine. By leveraging distributed GPUs across hosts, we handle larger mesh shapes and box sizes efficiently.
Note: Since there’s no direct way to run a multi-host notebook, I’ll guide you step by step on how to submit an interactive job from a script.
To run a multi-host simulation, you first need to allocate a job with salloc
. This command requests resources on an HPC cluster.
Note: You can alternatively use
sbatch
with a SLURM script to submit the job. The exactsalloc
parameters may vary depending on your specific HPC cluster configuration.
!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 &
Note: These
salloc
parameters are configured for the Jean Zay supercomputer in France. Adaptations might be necessary if using a different HPC cluster.
A few hours later
Use !squeue -u $USER -o "%i %D %b"
to check the JOB ID and verify your resource allocation.
In this example, we’ve been allocated 32 GPUs split across 4 nodes.
!squeue -u $USER -o "%i %D %b"
Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:
import os
del os.environ['VSCODE_PROXY_URI']
del os.environ['NO_PROXY']
del os.environ['no_proxy']
Checking Available Compute Resources¶
Run the following command to initialize JAX distributed computing and display the devices available for this job:
!srun --jobid=467745 -n 32 python -c "import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None"
Running the Multi-Host Simulation Script¶
Run the simulation script across 32 processes:
!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2
The script, located in the same path as this notebook, is named 05-MultiHost_PM.py.
Multi-Host Simulation Script with Arguments¶
This script is nearly identical to the single-host version, with the main addition being the call to jax.distributed.initialize()
at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:
--pdims
(-p
): Specifies processor grid dimensions as two integers, like16 2
for 16 x 2 device mesh (default is[1, jax.devices()]
).--mesh_shape
(-m
): Defines the simulation mesh shape as three integers (default is[512, 512, 512]
).--box_size
(-b
): Sets the physical box size of the simulation as three floating-point values, e.g.,1000. 1000. 1000.
(default is[500.0, 500.0, 500.0]
).--halo_size
(-H
): Specifies the halo size for boundary overlap across nodes (default is64
).--solver
(-s
): Chooses the ODE solver (leapfrog
ordopri8
). Theleapfrog
solver uses a fixed step size, whiledopri8
is an adaptive Runge-Kutta solver with a PID controller (default isleapfrog
).--snapthots
(-st
) : Number of snapshots to save (warning, increases memory usage)
The script also saves results across nodes.
import subprocess
# Define parameters as variables
jobid = "467745"
num_processes = 32
script_name = "05-MultiHost_PM.py"
mesh_shape = (1024, 1024, 1024)
box_size = (1000., 1000., 1000.)
halo_size = 128
solver = "leapfrog"
pdims = (16, 2)
snapshots = 2
# Build the command as a list, incorporating variables
command = [
"srun",
f"--jobid={jobid}",
"-n", str(num_processes),
"python", script_name,
"--mesh_shape", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),
"--box_size", str(box_size[0]), str(box_size[1]), str(box_size[2]),
"--halo_size", str(halo_size),
"-s", solver,
"--pdims", str(pdims[0]), str(pdims[1]),
"--snapshots", str(snapshots)
]
# Execute the command as a subprocess
subprocess.run(command)
Loading and Visualizing Results¶
After running the multi-host simulation, we load the saved results from disk:
initial_conditions.npy
: Initial conditions for the simulation.lpt_displacements.npy
: Linear perturbation displacements.ode_solution_0.npy
andode_solution_1.npy
: Solutions from the ODE solver at each snapshot.
We then use plot_fields_single_projection
to visualize these fields and observe the results across multiple snapshots.
import numpy as np
initial_conditions = np.load('fields/initial_conditions.npy')
lpt_displacements = np.load('fields/lpt_displacements.npy')
ode_solution_0 = np.load('fields/ode_solution_0.npy')
ode_solution_1 = np.load('fields/ode_solution_1.npy')
from jaxpm.plotting import plot_fields_single_projection
fields = {
"Initial Conditions": initial_conditions,
"LPT Field": lpt_displacements,
"ODE Solution 0": ode_solution_0,
"ODE Solution 1": ode_solution_1
}
plot_fields_single_projection(fields)