Multi-Host Particle Mesh Simulation¶
In this notebook, we extend our Particle Mesh simulation across multiple nodes, enabling simulations at scales not achievable on a single machine. By leveraging distributed GPUs across hosts, we handle larger mesh shapes and box sizes efficiently.
Note: Since there’s no direct way to run a multi-host notebook, I’ll guide you step by step through allocating an interactive job and running the simulation script on it.
To run a multi-host simulation, you first need to allocate a job with salloc. This command requests resources on an HPC cluster.
Note: You can alternatively use sbatch with a SLURM script to submit the job. The exact salloc parameters may vary depending on your specific HPC cluster configuration.
!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 &
Note: These salloc parameters are configured for the Jean Zay supercomputer in France. Adaptations might be necessary if using a different HPC cluster.
A few hours later
Use !squeue -u $USER -o "%i %D %b" to check the JOB ID and verify your resource allocation.
In this example, we’ve been allocated 32 GPUs split across 4 nodes.
!squeue -u $USER -o "%i %D %b"
Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:
import os

# Remove proxy-related variables (if set) that can interfere with JAX's
# distributed initialization
for var in ("VSCODE_PROXY_URI", "NO_PROXY", "no_proxy"):
    os.environ.pop(var, None)
Checking Available Compute Resources¶
Run the following command (substituting your own job ID from squeue) to initialize JAX distributed computing and display the devices available for this job:
!srun --jobid=467745 -n 32 python -c "import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None"
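For reference, here is a minimal sketch of the distributed setup such a script performs at startup. It is an illustration using standard JAX APIs, not the exact contents of 05-MultiHost_PM.py:
import jax

# Connect all processes launched by srun; JAX reads the SLURM environment
# to discover the coordinator address and the number of processes.
jax.distributed.initialize()

# Print a summary from a single process to avoid one copy per process
if jax.process_index() == 0:
    print(f"{jax.process_count()} processes, "
          f"{jax.device_count()} total devices, "
          f"{jax.local_device_count()} local devices per process")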
Running the Multi-Host Simulation Script¶
Run the simulation script across 32 processes:
!srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2
The script, located in the same directory as this notebook, is named 05-MultiHost_PM.py.
Multi-Host Simulation Script with Arguments¶
This script is nearly identical to the single-host version, with the main addition being the call to jax.distributed.initialize() at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments (a sketch of how they might be parsed follows the list):
--pdims (-p): Specifies the processor grid dimensions as two integers, e.g. 16 2 for a 16 x 2 device mesh (default is [1, jax.devices()]).
--mesh_shape (-m): Defines the simulation mesh shape as three integers (default is [512, 512, 512]).
--box_size (-b): Sets the physical box size of the simulation as three floating-point values, e.g. 1000. 1000. 1000. (default is [500.0, 500.0, 500.0]).
--halo_size (-H): Specifies the halo size for the boundary overlap across nodes (default is 64).
--solver (-s): Chooses the ODE solver (leapfrog or dopri8). The leapfrog solver uses a fixed step size, while dopri8 is an adaptive Runge-Kutta solver with a PID controller (default is leapfrog).
--snapshots (-st): Number of snapshots to save (warning: increases memory usage).
The script also saves the results computed across the nodes to disk.
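One common pattern for saving a field that is sharded across hosts is to gather it onto every process and write it from a single one. Below is a minimal sketch, assuming the standard jax.experimental.multihost_utils helper and a fields/ output directory; it is not necessarily what 05-MultiHost_PM.py actually does:
import os

import jax
import numpy as np
from jax.experimental import multihost_utils

def save_field(name, sharded_field):
    # Gather the globally sharded array as a full numpy array on every process
    gathered = np.asarray(
        multihost_utils.process_allgather(sharded_field, tiled=True))
    # Write from a single process to avoid every host writing the same file
    if jax.process_index() == 0:
        os.makedirs("fields", exist_ok=True)
        np.save(f"fields/{name}.npy", gathered)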
import subprocess

# Define parameters as variables
jobid = "467745"
num_processes = 32
script_name = "05-MultiHost_PM.py"
mesh_shape = (1024, 1024, 1024)
box_size = (1000., 1000., 1000.)
halo_size = 128
solver = "leapfrog"
pdims = (16, 2)
snapshots = 2

# Build the command as a list, incorporating variables
command = [
    "srun",
    f"--jobid={jobid}",
    "-n", str(num_processes),
    "python", script_name,
    "--mesh_shape", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),
    "--box_size", str(box_size[0]), str(box_size[1]), str(box_size[2]),
    "--halo_size", str(halo_size),
    "-s", solver,
    "--pdims", str(pdims[0]), str(pdims[1]),
    "--snapshots", str(snapshots)
]

# Execute the command as a subprocess
subprocess.run(command)
Loading and Visualizing Results¶
After running the multi-host simulation, we load the saved results from disk:
initial_conditions.npy: Initial conditions for the simulation.
lpt_displacements.npy: Displacements from Lagrangian perturbation theory (LPT).
ode_solution_0.npy and ode_solution_1.npy: Solutions from the ODE solver at each snapshot.
We then use plot_fields_single_projection to visualize these fields and observe the results across multiple snapshots.
import numpy as np

# Load the fields saved by the multi-host run
initial_conditions = np.load('fields/initial_conditions.npy')
lpt_displacements = np.load('fields/lpt_displacements.npy')
ode_solution_0 = np.load('fields/ode_solution_0.npy')
ode_solution_1 = np.load('fields/ode_solution_1.npy')

from jaxpm.plotting import plot_fields_single_projection

fields = {
    "Initial Conditions": initial_conditions,
    "LPT Field": lpt_displacements,
    "ODE Solution 0": ode_solution_0,
    "ODE Solution 1": ode_solution_1
}
plot_fields_single_projection(fields)