Wassim KABALAN 2024-08-02 23:39:09 +02:00
parent 9af4659c81
commit 831291c1f9
8 changed files with 790 additions and 114 deletions

benchmarks/bench_pm.py (new file, 249 lines added)

@@ -0,0 +1,249 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
import argparse
import time
from hpc_plotter.timer import Timer
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, Tsit5, diffeqsolve)
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
def run_simulation(mesh_shape,
box_size,
halo_size,
solver_choice,
iterations,
pdims=None):
@jax.jit
def simulate(omega_c, sigma8):
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
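# pk_fn is later evaluated by linear_field on the wavenumbers of the simulation mesh,
# interpolating the tabulated P(k) defined above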
# Create initial conditions
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0))
# Create particles
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
if solver_choice == "Dopri5":
solver = Dopri5()
elif solver_choice == "LeapfrogMidpoint":
solver = LeapfrogMidpoint()
elif solver_choice == "Tsit5":
solver = Tsit5()
elif solver_choice == "lpt":
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
print(f"TYPE of lpt_field: {type(lpt_field)}")
return lpt_field, {"num_steps": 0}
else:
raise ValueError(
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.")
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
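# diffrax expects a vector field with signature f(t, y, args); the jaxpm ODE function
# returns (dpos, dvel) for the stacked state y = [dx, p], so the derivatives are re-stacked here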
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
if solver_choice == "Dopri5" or solver_choice == "Tsit5":
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
stepsize_controller = ConstantStepSize()
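# SaveAt(t1=True) keeps only the state at the final time t1, so res.ys holds a single snapshot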
res = diffeqsolve(term,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(t1=True),
stepsize_controller=stepsize_controller)
# Return the simulation volume at the requested final time
state = res.ys[-1]
final_field = cic_paint_dx(state[0], halo_size=halo_size)
return final_field, res.stats
def run():
# Warm start
chrono_fun = Timer()
RangePush("warmup")
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8, ndarray_arg=0)
RangePop()
sync_global_devices("warmup")
for i in range(iterations):
RangePush(f"sim iter {i}")
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8, ndarray_arg=0)
RangePop()
return final_field, stats, chrono_fun
if jax.device_count() > 1:
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
with mesh:
# Warm start
final_field, stats, chrono_fun = run()
else:
final_field, stats, chrono_fun = run()
return final_field, stats, chrono_fun
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='JAX Cosmo Simulation Benchmark')
parser.add_argument('-m',
'--mesh_size',
type=int,
help='Mesh size',
required=True)
parser.add_argument('-b',
'--box_size',
type=float,
help='Box size',
required=True)
parser.add_argument('-p',
'--pdims',
type=str,
help='Processor dimensions',
default=None)
parser.add_argument('-pr',
'--precision',
type=str,
help='Precision',
choices=["float32", "float64"],)
parser.add_argument('-hs',
'--halo_size',
type=int,
help='Halo size',
required=True)
parser.add_argument('-s',
'--solver',
type=str,
help='Solver',
choices=[
"Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
"lpt"
],
required=True)
parser.add_argument('-i',
'--iterations',
type=int,
help='Number of iterations',
default=10)
parser.add_argument('-o',
'--output_path',
type=str,
help='Output path',
default=".")
parser.add_argument('-f',
'--save_fields',
action='store_true',
help='Save fields')
parser.add_argument('-n',
'--nodes',
type=int,
help='Number of nodes',
default=1)
args = parser.parse_args()
mesh_size = args.mesh_size
box_size = [args.box_size] * 3
halo_size = args.halo_size
solver_choice = args.solver
iterations = args.iterations
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)
print(f"solver choice: {solver_choice}")
match solver_choice:
case "Dopri5" | "dopri5"| "d5":
solver_choice = "Dopri5"
case "Tsit5"| "tsit5"| "t5":
solver_choice = "Tsit5"
case "LeapfrogMidpoint"| "leapfrogmidpoint"| "lfm":
solver_choice = "LeapfrogMidpoint"
case "lpt":
solver_choice = "lpt"
case _:
raise ValueError(
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt"
)
if args.precision == "float32":
jax.config.update("jax_enable_x64", False)
elif args.precision == "float64":
jax.config.update("jax_enable_x64", True)
if args.pdims:
pdims = tuple(map(int, args.pdims.split("x")))
else:
pdims = (1, 1)
mesh_shape = [mesh_size] * 3
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
metadata = {
'rank': rank,
'function_name': f'JAXPM-{solver_choice}',
'precision': args.precision,
'x': str(mesh_size),
'y': str(mesh_size),
'z': str(stats["num_steps"]),
'px': str(pdims[0]),
'py': str(pdims[1]),
'backend': 'NCCL',
'nodes': str(args.nodes)
}
# Print the results to a CSV file
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
# Save the final field
nb_gpus = jax.device_count()
pdm_str = f"{pdims[0]}x{pdims[1]}"
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
os.makedirs(field_folder, exist_ok=True)
with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i, t in enumerate(chrono_fun.times):
f.write(f"Time {i}: {t:.4f} ms\n")
f.write(f"Stats: {stats}\n")
if args.save_fields:
np.save(f'{field_folder}/final_field_0_{rank}.npy',
final_field.addressable_data(0))
print(f"Finished! ")
print(f"Stats {stats}")
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
print(f"Saving field and logs in {field_folder}")

benchmarks/bench_pmwd.py (new file, 131 lines added)

@@ -0,0 +1,131 @@
import os
# Change JAX GPU memory preallocation fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
import jax
import argparse
import numpy as np
import matplotlib.pyplot as plt
from pmwd import (
Configuration,
Cosmology, SimpleLCDM,
boltzmann, linear_power, growth,
white_noise, linear_modes,
lpt, nbody, scatter
)
from pmwd.pm_util import fftinv
from pmwd.spec_util import powspec
from pmwd.vis_util import simshow
from hpc_plotter.timer import Timer
# Simulation configuration
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver , iterations):
@jax.jit
def simulate(omega_m, sigma8):
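# Configuration notes: mesh_shape=1 makes the force mesh the same size as the particle grid,
# lpt_order=1 selects first-order (Zel'dovich) initial conditions, and a_nbody_maxstep
# bounds the size of the N-body time steps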
conf = Configuration(ptcl_spacing, ptcl_grid_shape=ptcl_grid_shape, mesh_shape=1,lpt_order=1,a_nbody_maxstep=1/91)
print(conf)
print(f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.')
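# NOTE: the second argument of simulate() is passed to Omega_b below; pmwd sets the
# fluctuation amplitude through A_s_1e9 rather than sigma8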
cosmo = Cosmology(conf, A_s_1e9=2.0, n_s=0.96, Omega_m=omega_m, Omega_b=sigma8, h=0.7)
print(cosmo)
# Boltzmann calculation
cosmo = boltzmann(cosmo, conf)
print("Boltzmann calculation completed.")
# Generate white noise field and scale with the linear power spectrum
seed = 0
modes = white_noise(seed, conf)
modes = linear_modes(modes, cosmo, conf)
print("Linear modes generated.")
# Solve LPT at some early time
ptcl, obsvbl = lpt(modes, cosmo, conf)
print("LPT solved.")
if solver == "lfm":
# N-body time integration from LPT initial conditions
ptcl, obsvbl = jax.block_until_ready(nbody(ptcl, obsvbl, cosmo, conf))
print("N-body time integration completed.")
# Scatter particles to mesh to get the density field
dens = scatter(ptcl, conf)
return dens
chrono_timer = Timer()
final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05)
for _ in range(iterations):
final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
return final_field , chrono_timer
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PMWD Simulation')
parser.add_argument('-m', '--mesh_size', type=int, help='Mesh size', required=True)
parser.add_argument('-b', '--box_size', type=float, help='Box size', required=True)
parser.add_argument('-i', '--iterations', type=int, help='Number of iterations', default=10)
parser.add_argument('-o', '--output_path', type=str, help='Output path', default=".")
parser.add_argument('-f', '--save_fields', action='store_true', help='Save fields')
parser.add_argument('-s', '--solver', type=str, help='Solver', choices=["lfm" , "lpt"])
parser.add_argument('-pr',
'--precision',
type=str,
help='Precision',
choices=["float32", "float64"],)
args = parser.parse_args()
mesh_shape = [args.mesh_size] * 3
ptcl_spacing = args.box_size /args.mesh_size
iterations = args.iterations
solver = args.solver
output_path = args.output_path
if args.precision == "float32":
jax.config.update("jax_enable_x64", False)
elif args.precision == "float64":
jax.config.update("jax_enable_x64", True)
os.makedirs(output_path, exist_ok=True)
final_field , chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, solver, iterations)
print("PMWD simulation completed.")
metadata = {
'rank': 0,
'function_name': f'PMWD-{solver}',
'precision': args.precision,
'x': str(mesh_shape[0]),
'y': str(mesh_shape[0]),
'z': str(mesh_shape[0]),
'px': "1",
'py': "1",
'backend': 'NCCL',
'nodes': "1"
}
chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata)
field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0"
os.makedirs(field_folder, exist_ok=True)
with open(f"{field_folder}/pmwd.log", "w") as f:
f.write(f"PMWD simulation completed.\n")
f.write(f"Args : {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i, t in enumerate(chrono_fun.times):
f.write(f"Time {i}: {t:.4f} ms\n")
if args.save_fields:
np.save(f"{field_folder}/final_field_0_0.npy", final_field)
print("Fields saved.")
print(f"saving to {output_path}/pmwd.csv")
print(f"saving field and logs to {field_folder}/pmwd.log")

benchmarks/particle_mesh_a100.slurm (new file, 183 lines added)

@@ -0,0 +1,183 @@
#!/bin/bash
##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ##
##########################################
#SBATCH --account tkc@a100
##########################################
#SBATCH --job-name=1N-FFT-Mesh # job name
# A partition other than the default one can be used
# by enabling one of the following 5 directives:
##########################################
## SELECT EITHER a100 or v100-32g ##
##########################################
#SBATCH -C a100
##########################################
#******************************************
##########################################
## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4
##########################################
#SBATCH --nodes=1 # number of nodes
#SBATCH --ntasks-per-node=8 # number of MPI tasks per node (= number of GPUs per node)
#SBATCH --gres=gpu:8 # number of GPUs per node (max 8 with gpu_p2, gpu_p5)
##########################################
## The number of CPUs per task must be adapted to the partition in use. Since
## only one GPU is reserved per task here (i.e. 1/4 or 1/8 of the node's GPUs depending
## on the partition), the ideal is to reserve 1/4 or 1/8 of the node's CPUs for each task:
##########################################
#SBATCH --cpus-per-task=8 # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
##########################################
# /!\ Caution: "multithread" refers to hyperthreading in Slurm terminology
#SBATCH --hint=nomultithread # hyperthreading disabled
#SBATCH --time=04:00:00 # maximum requested run time (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out # name of the output file
#SBATCH --error=%x_%N_a100.out # name of the error file (here merged with the output)
##SBATCH --qos=qos_gpu-dev
## SBATCH --exclusive # dedicated resources
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [ "$SLURM_JOB_PARTITION" = "gpu_p5" ]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the commands as they are executed
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in current versions of Nsight Systems, it is also
# necessary to create a symbolic link so that the /tmp/nvidia directory
# points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
declare -A pdims_table
# Define the table
pdims_table[1]="1x1"
pdims_table[4]="2x2 1x4 4x1"
pdims_table[8]="2x4 1x8 8x1 4x2"
pdims_table[16]="4x4 1x16 16x1"
pdims_table[32]="4x8 8x4 1x32 32x1"
pdims_table[64]="8x8 16x4 1x64 64x1"
pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2"
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
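# The table maps the total GPU count to the set of 2D processor grids (PXxPY) to benchmark;
# it is looked up below through pdims_table[$nb_gpus]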
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096)
precisions=(float32 float64)
pdim="${pdims_table[$nb_gpus]}"
solvers=(lpt lfm)
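# 'lpt' benchmarks the LPT-only field in bench_pm.py, 'lfm' the full LeapfrogMidpoint integration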
echo "pdims: $pdim"
# Check if pdims is not empty
if [ -z "$pdim" ]; then
echo "pdims is empty"
echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
echo "Number of nodes selected: $num_nodes"
echo "Number of gpus per node: $num_gpu_per_node"
exit 1
fi
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
out_dir="pm_prof/$gpu_name/$nb_gpus"
echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
for p in $pdim; do
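# halo (ghost-cell) size is set to a quarter of the mesh size per dimension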
halo_size=$((g / 4))
slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
done
done
done
done

benchmarks/particle_mesh_v100.slurm (new file, 184 lines added)

@@ -0,0 +1,184 @@
#!/bin/bash
##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ##
##########################################
#SBATCH --account tkc@v100
##########################################
#SBATCH --job-name=V100Particle-Mesh # job name
# A partition other than the default one can be used
# by enabling one of the following 5 directives:
##########################################
## SELECT EITHER a100 or v100-32g ##
##########################################
#SBATCH -C v100-32g
##########################################
#******************************************
##########################################
## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4
##########################################
#SBATCH --nodes=1 # number of nodes
#SBATCH --ntasks-per-node=4 # number of MPI tasks per node (= number of GPUs per node)
#SBATCH --gres=gpu:4 # number of GPUs per node (max 8 with gpu_p2, gpu_p5)
##########################################
## The number of CPUs per task must be adapted to the partition in use. Since
## only one GPU is reserved per task here (i.e. 1/4 or 1/8 of the node's GPUs depending
## on the partition), the ideal is to reserve 1/4 or 1/8 of the node's CPUs for each task:
##########################################
#SBATCH --cpus-per-task=8 # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
##########################################
# /!\ Caution: "multithread" refers to hyperthreading in Slurm terminology
#SBATCH --hint=nomultithread # hyperthreading disabled
#SBATCH --time=02:00:00 # maximum requested run time (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out # name of the output file
#SBATCH --error=%x_%N_a100.out # name of the error file (here merged with the output)
#SBATCH --qos=qos_gpu-dev
#SBATCH --exclusive # dedicated resources
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [ "$SLURM_JOB_PARTITION" = "gpu_p5" ]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export CUDA_ALLOC=1
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the commands as they are executed
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in current versions of Nsight Systems, it is also
# necessary to create a symbolic link so that the /tmp/nvidia directory
# points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
declare -A pdims_table
# Define the table
pdims_table[1]="1x1"
pdims_table[4]="2x2 1x4 4x1"
pdims_table[8]="2x4 1x8 8x1 4x2"
pdims_table[16]="4x4 1x16 16x1"
pdims_table[32]="4x8 8x4 1x32 32x1"
pdims_table[64]="8x8 16x4 1x64 64x1"
pdims_table[128]="8x16 16x8 4x32 1x128 128x1"
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
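# The table maps the total GPU count to the set of 2D processor grids (PXxPY) to benchmark;
# it is looked up below through pdims_table[$nb_gpus]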
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096)
precisions=(float32 float64)
pdim="${pdims_table[$nb_gpus]}"
solvers=(lpt lfm)
echo "pdims: $pdim"
# Check if pdims is not empty
if [ -z "$pdim" ]; then
echo "pdims is empty"
echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
echo "Number of nodes selected: $num_nodes"
echo "Number of gpus per node: $num_gpu_per_node"
exit 1
fi
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
out_dir="pm_prof/$gpu_name/$nb_gpus"
echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
for p in $pdim; do
halo_size=$((g / 4))
slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
done
done
done
done

benchmarks/pmwd_a100.slurm (new file, 165 lines added)

@@ -0,0 +1,165 @@
#!/bin/bash
##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ##
##########################################
#SBATCH --account tkc@a100
##########################################
#SBATCH --job-name=1N-FFT-Mesh # job name
# A partition other than the default one can be used
# by enabling one of the following 5 directives:
##########################################
## SELECT EITHER a100 or v100-32g ##
##########################################
#SBATCH -C a100
##########################################
#******************************************
##########################################
## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4
##########################################
#SBATCH --nodes=1 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MPI tasks per node (= number of GPUs per node)
#SBATCH --gres=gpu:1 # number of GPUs per node (max 8 with gpu_p2, gpu_p5)
##########################################
## The number of CPUs per task must be adapted to the partition in use. Since
## only one GPU is reserved per task here (i.e. 1/4 or 1/8 of the node's GPUs depending
## on the partition), the ideal is to reserve 1/4 or 1/8 of the node's CPUs for each task:
##########################################
#SBATCH --cpus-per-task=8 # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
##########################################
# /!\ Caution: "multithread" refers to hyperthreading in Slurm terminology
#SBATCH --hint=nomultithread # hyperthreading disabled
#SBATCH --time=04:00:00 # maximum requested run time (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out # name of the output file
#SBATCH --error=%x_%N_a100.out # name of the error file (here merged with the output)
##SBATCH --qos=qos_gpu-dev
## SBATCH --exclusive # dedicated resources
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [ "$SLURM_JOB_PARTITION" = "gpu_p5" ]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export CUDA_ALLOC=1
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the commands as they are executed
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in current versions of Nsight Systems, it is also
# necessary to create a symbolic link so that the /tmp/nvidia directory
# points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096)
precisions=(float32 float64)
solvers=(lpt lfm)
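# 'lpt' stops bench_pmwd.py after the LPT step; 'lfm' also runs the full pmwd nbody integration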
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
if [ $num_gpu_per_node -eq 8 ]; then
gpu_name="a100"
else
gpu_name="v100"
fi
out_dir="pm_prof/$gpu_name/$nb_gpus"
echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
done
done
done

benchmarks/pmwd_v100.slurm (new file, 170 lines added)

@@ -0,0 +1,170 @@
#!/bin/bash
##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ##
##########################################
#SBATCH --account tkc@v100
##########################################
#SBATCH --job-name=16N-V100Particle-Mesh # job name
# A partition other than the default one can be used
# by enabling one of the following 5 directives:
##########################################
## SELECT EITHER a100 or v100-32g ##
##########################################
#SBATCH -C v100-32g
##########################################
#******************************************
##########################################
## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4
##########################################
#SBATCH --nodes=1 # number of nodes
#SBATCH --ntasks-per-node=1 # number of MPI tasks per node (= number of GPUs per node)
#SBATCH --gres=gpu:1 # number of GPUs per node (max 8 with gpu_p2, gpu_p5)
##########################################
## The number of CPUs per task must be adapted to the partition in use. Since
## only one GPU is reserved per task here (i.e. 1/4 or 1/8 of the node's GPUs depending
## on the partition), the ideal is to reserve 1/4 or 1/8 of the node's CPUs for each task:
##########################################
#SBATCH --cpus-per-task=8 # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
##########################################
# /!\ Caution: "multithread" refers to hyperthreading in Slurm terminology
#SBATCH --hint=nomultithread # hyperthreading disabled
#SBATCH --time=02:00:00 # maximum requested run time (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out # name of the output file
#SBATCH --error=%x_%N_a100.out # name of the error file (here merged with the output)
#SBATCH --qos=qos_gpu-dev
#SBATCH --exclusive # dedicated resources
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [ "$SLURM_JOB_PARTITION" = "gpu_p5" ]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export CUDA_ALLOC=1
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the commands as they are executed
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in current versions of Nsight Systems, it is also
# necessary to create a symbolic link so that the /tmp/nvidia directory
# points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096)
precisions=(float32 float64)
solvers=(lpt lfm)
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
if [ $num_gpu_per_node -eq 8 ]; then
gpu_name="a100"
else
gpu_name="v100"
fi
out_dir="pm_prof/$gpu_name/$nb_gpus"
echo "Output dir is : $out_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
done
done
done

benchmarks/run_all_jobs.sh (new executable file, 19 lines added)

@@ -0,0 +1,19 @@
#!/bin/bash
# Run all Slurm jobs
nodes_v100=(1 2 4 8 16)
nodes_a100=(1 2 4 8 16)
for n in ${nodes_v100[@]}; do
sbatch --nodes=$n --job-name=v100_$n-JAXPM particle_mesh_v100.slurm
done
for n in ${nodes_a100[@]}; do
sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm
done
# single GPUs
sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --ntasks-per-node=1 particle_mesh_v100.slurm
sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --ntasks-per-node=1 particle_mesh_a100.slurm
sbatch --job-name=PMWD-v100 pmwd_v100.slurm
sbatch --job-name=PMWD-a100 pmwd_a100.slurm
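
These submissions assume the Slurm scripts and the Python benchmarks sit in the same directory, so a typical launch (run from benchmarks/, letting the jobs find bench_pm.py and bench_pmwd.py via their relative paths) would simply be:

cd benchmarks && bash run_all_jobs.sh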