Update benchmark and add SLURM script

Wassim KABALAN 2024-07-18 17:04:52 +02:00
parent ed8cf8e532
commit 0216837033
2 changed files with 226 additions and 24 deletions


@@ -15,8 +15,8 @@ import jax.numpy as jnp
 import jax_cosmo as jc
 import numpy as np
 from cupy.cuda.nvtx import RangePop, RangePush
-from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt,
-                     diffeqsolve)
+from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
+                     PIDController, SaveAt, Tsit5, diffeqsolve)
 from jax.experimental import mesh_utils
 from jax.experimental.multihost_utils import sync_global_devices
 from jax.sharding import Mesh, NamedSharding
@@ -29,7 +29,8 @@ from jaxpm.pm import linear_field, lpt, make_ode_fn
 def chrono_fun(fun, *args):
     start = time.perf_counter()
-    out = fun(*args).block_until_ready()
+    out = fun(*args)
+    out[0].block_until_ready()
     end = time.perf_counter()
     return out, end - start
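Note: the timed function now returns a (field, stats) tuple, and a Python dict has no block_until_ready, so the timer blocks on the device array element only. A minimal, self-contained sketch of this timing pattern; the toy function below is illustrative, not the benchmark's:

    import time
    import jax.numpy as jnp

    def chrono_fun(fun, *args):
        start = time.perf_counter()
        out = fun(*args)
        out[0].block_until_ready()  # JAX dispatches asynchronously; wait on the array
        end = time.perf_counter()
        return out, end - start

    toy = lambda x: (jnp.fft.fftn(x), {"num_steps": 0})  # (device array, host-side stats)
    (field, stats), elapsed = chrono_fun(toy, jnp.ones((32, 32, 32)))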
@@ -59,18 +60,22 @@ def run_simulation(mesh_shape,
     cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
     dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
-    # Evolve the simulation forward
-    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
-    term = ODETerm(
-        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
     if solver_choice == "Dopri5":
         solver = Dopri5()
     elif solver_choice == "LeapfrogMidpoint":
         solver = LeapfrogMidpoint()
+    elif solver_choice == "Tsit5":
+        solver = Tsit5()
+    elif solver_choice == "lpt":
+        lpt_field = cic_paint_dx(dx, halo_size=halo_size)
+        return lpt_field, {"num_steps": 0, "Solver": "LPT"}
     else:
         raise ValueError(
             "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'.")
+    # Evolve the simulation forward
+    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
+    term = ODETerm(
+        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
     stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
     res = diffeqsolve(term,
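Note: LeapfrogMidpoint is a fixed-step solver, which is presumably why ConstantStepSize is now imported alongside it; diffrax's adaptive PIDController suits the adaptive solvers (Dopri5, Tsit5). A minimal, self-contained sketch of the two pairings on a toy ODE, not the benchmark's ODE:

    import jax.numpy as jnp
    from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                         PIDController, diffeqsolve)

    term = ODETerm(lambda t, y, args: -y)  # dy/dt = -y
    y0 = jnp.array(1.0)

    # Fixed step: LeapfrogMidpoint requires a constant step-size controller
    fixed = diffeqsolve(term, LeapfrogMidpoint(), t0=0.0, t1=1.0, dt0=0.01,
                        y0=y0, stepsize_controller=ConstantStepSize())

    # Adaptive step: Dopri5/Tsit5 can use the PID controller as in the benchmark
    adaptive = diffeqsolve(term, Dopri5(), t0=0.0, t1=1.0, dt0=0.01, y0=y0,
                           stepsize_controller=PIDController(rtol=1e-4, atol=1e-4))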
@@ -93,17 +98,17 @@ def run_simulation(mesh_shape,
         # Warm start
         times = []
         RangePush("warmup")
-        final_field, stats, warmup_time = chrono_fun(simulate, 0.32, 0.8)
+        (final_field, stats), warmup_time = chrono_fun(simulate, 0.32, 0.8)
         RangePop()
         sync_global_devices("warmup")
         for i in range(iterations):
             RangePush(f"sim iter {i}")
-            final_field, stats, sim_time = chrono_fun(simulate, 0.32, 0.8)
+            (final_field, stats), sim_time = chrono_fun(simulate, 0.32, 0.8)
             RangePop()
             times.append(sim_time)
         return stats, warmup_time, times, final_field

-    if jax.device_count() > 1 and pdims:
+    if jax.device_count() > 1:
         devices = mesh_utils.create_device_mesh(pdims)
         mesh = Mesh(devices.T, axis_names=('x', 'y'))
         with mesh:
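Note: mesh_utils.create_device_mesh requires the product of pdims to match the number of available devices, which is presumably why pdims later defaults to (1, 1) rather than None. A short sketch of the mesh setup; the 2x4 layout is illustrative and assumes 8 devices:

    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    pdims = (2, 4)  # assumes jax.device_count() == 8
    devices = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(devices.T, axis_names=('x', 'y'))
    with mesh:
        pass  # sharded simulation code runs inside this context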
@@ -134,7 +139,7 @@ if __name__ == "__main__":
                         type=str,
                         help='Processor dimensions',
                         default=None)
-    parser.add_argument('-h',
+    parser.add_argument('-hs',
                         '--halo_size',
                         type=int,
                         help='Halo size',
@@ -143,7 +148,11 @@ if __name__ == "__main__":
                         '--solver',
                         type=str,
                         help='Solver',
-                        choices=["Dopri5", "LeapfrogMidpoint"],
+                        choices=[
+                            "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
+                            "LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
+                            "lpt"
+                        ],
                         required=True)
     parser.add_argument('-i',
                         '--iterations',
@@ -170,10 +179,24 @@ if __name__ == "__main__":
     output_path = args.output_path
     os.makedirs(output_path, exist_ok=True)

+    match solver_choice:
+        case "Dopri5" | "dopri5" | "d5":
+            solver_choice = "Dopri5"
+        case "Tsit5" | "tsit5" | "t5":
+            solver_choice = "Tsit5"
+        case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
+            solver_choice = "LeapfrogMidpoint"
+        case "lpt":
+            solver_choice = "lpt"
+        case _:
+            raise ValueError(
+                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'."
+            )
     if args.pdims:
         pdims = tuple(map(int, args.pdims.split("x")))
     else:
-        pdims = None
+        pdims = (1, 1)

     mesh_shape = [mesh_size] * 3
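A quick worked example of the pdims parsing above, with an illustrative flag value:

    pdims = tuple(map(int, "4x8".split("x")))  # "4x8" -> (4, 8); the product must match the device count when a mesh is built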
@@ -184,14 +207,6 @@ if __name__ == "__main__":
                                                           iterations,
                                                           pdims=pdims)
-    # Save the final field
-    if args.save_fields:
-        nb_gpus = jax.device_count()
-        field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{box_size[0]}/{solver_choice}/{halo_size}"
-        os.makedirs(field_folder, exist_ok=True)
-        np.save(f'{field_folder}/final_field_{rank}.npy',
-                final_field.addressable_data(0))
     # Write benchmark results to CSV
     # RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
     times = np.array(times)
@@ -200,11 +215,31 @@ if __name__ == "__main__":
     max_time = np.max(times) * 1000
     mean_time = np.mean(times) * 1000
     std_time = np.std(times) * 1000

     with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
         f.write(
-            f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
+            f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{stats['num_steps']},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
         )

+    # Save the final field
+    nb_gpus = jax.device_count()
+    pdm_str = f"{pdims[0]}x{pdims[1]}"
+    field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/{halo_size}"
+    os.makedirs(field_folder, exist_ok=True)
+    with open(f'{field_folder}/jaxpm.log', 'w') as f:
+        f.write(f"Args: {args}\n")
+        f.write(f"JIT time: {jit_in_ms:.4f} ms\n")
+        f.write(f"Min time: {min_time:.4f} ms\n")
+        f.write(f"Max time: {max_time:.4f} ms\n")
+        f.write(f"Mean time: {mean_time:.4f} ms\n")
+        f.write(f"Std time: {std_time:.4f} ms\n")
+        f.write(f"Stats: {stats}\n")
+    if args.save_fields:
+        np.save(f'{field_folder}/final_field_0_{rank}.npy',
+                final_field.addressable_data(0))

     print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
     print(f"mean times: {np.mean(times):.4f}")
-    print(f"Stats")
+    print(f"Stats {stats}")
     print(f"Saving to {output_path}/jax_pm_benchmark.csv")
+    print(f"Saving field and logs in {field_folder}")

scripts/particle_mesh.slurm Normal file (+167 lines)

@@ -0,0 +1,167 @@
#!/bin/bash
##########################################
## SELECT EITHER tkc@a100 OR tkc@v100 ##
##########################################
#SBATCH --account tkc@a100
##########################################
#SBATCH --job-name=Particle-Mesh     # job name
# A partition other than the default one can be used
# by activating one of the 5 following directives:
##########################################
## SELECT EITHER a100 or v100-32g ##
##########################################
#SBATCH -C a100
##########################################
#******************************************
##########################################
## SELECT Number of nodes and GPUs per node
## For A100 ntasks-per-node and gres=gpu should be 8
## For V100 ntasks-per-node and gres=gpu should be 4
##########################################
#SBATCH --nodes=1                    # number of nodes
#SBATCH --ntasks-per-node=8          # number of MPI tasks per node (= number of GPUs per node)
#SBATCH --gres=gpu:8                 # number of GPUs per node (max 8 with gpu_p2, gpu_p5)
##########################################
## The number of CPUs per task must be adapted to the partition in use. Since
## only one GPU is reserved per task (i.e. 1/4 or 1/8 of the node's GPUs,
## depending on the partition), ideally reserve 1/4 or 1/8 of the node's CPUs
## for each task:
##########################################
#SBATCH --cpus-per-task=8            # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
##########################################
# /!\ Note: in Slurm terminology, "multithread" refers to hyperthreading.
#SBATCH --hint=nomultithread         # hyperthreading disabled
#SBATCH --time=04:00:00              # maximum requested runtime (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out      # name of the output file
#SBATCH --error=%x_%N_a100.out       # name of the error file (here merged with the output)
#SBATCH --qos=qos_gpu-dev
#SBATCH --exclusive                  # dedicated resources
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [ $num_gpu_per_node -eq 8 ]; then
    module load cpuarch/amd
    source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
else
    source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export CUDA_ALLOC=1
function profile_python() {
    if [ $# -lt 1 ]; then
        echo "Usage: profile_python <python_script> [arguments for the script]"
        return 1
    fi
    local script_name=$(basename "$1" .py)
    local output_dir="prof_traces/$script_name"
    local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
        local args=$(echo "${@:2}" | tr ' ' '_')
        # Remove characters '/' and '-' from folder name
        args=$(echo "$args" | tr -d '/-')
        output_dir="prof_traces/$script_name/$args"
        report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
    fi
    mkdir -p "$output_dir"
    mkdir -p "$report_dir"
    srun nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
    if [ $# -lt 1 ]; then
        echo "Usage: run_python <python_script> [arguments for the script]"
        return 1
    fi
    local script_name=$(basename "$1" .py)
    local output_dir="traces/$script_name"
    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
        local args=$(echo "${@:2}" | tr ' ' '_')
        # Remove characters '/' and '-' from folder name
        args=$(echo "$args" | tr -d '/-')
        output_dir="traces/$script_name/$args"
    fi
    mkdir -p "$output_dir"
    srun python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
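# Usage sketch for the two helpers above (the arguments are illustrative and
# mirror the sweep at the bottom of this script):
#   run_python scripts/fastpm_jaxdecomp.py -m 1024 -b 1024 -p 2x4 -hs 256 -ode diffrax -o out/a100/8
#   profile_python scripts/fastpm_jaxdecomp.py -m 1024 -b 1024 -p 2x4 -hs 256 -ode diffrax -o out/a100/8
# run_python writes stdout/stderr under traces/<script>/<args>; profile_python
# additionally wraps srun in "nsys profile" and stores one report per rank
# under out_prof/<gpu_name>/<nb_gpus>/<script>/<args>.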
# Echo the commands being run
set -x
# For the "gpu_p5" partition, the code must be compiled with the compatible modules
# Run the code with GPU binding via bind_gpu.sh: 1 GPU per task
declare -A pdims_table
# Define the table
pdims_table[4]="2x2 1x4"
pdims_table[8]="2x4 1x8"
pdims_table[16]="2x8 1x16"
pdims_table[32]="4x8 1x32"
pdims_table[64]="4x16 1x64"
pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2"
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
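# Example lookup (values from the table above): with --nodes=1 and 8 tasks per
# node, nb_gpus=8, so pdims_table[8] yields "2x4 1x8" and each grid size below
# is swept over both decompositions.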
#mpch=(128 256 512 1024 2048 4096)
grid=(1024 2048 4096)
pdim="${pdims_table[$nb_gpus]}"
echo "pdims: $pdim"
# Check if pdims is not empty
if [ -z "$pdim" ]; then
    echo "pdims is empty"
    echo "Number of GPUs has to be 4, 8, 16, 32, 64, 128 or 160"
    echo "Number of nodes selected: $num_nodes"
    echo "Number of GPUs per node: $num_gpu_per_node"
    exit 1
fi
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
if [ $num_gpu_per_node -eq 8 ]; then
    gpu_name="a100"
else
    gpu_name="v100"
fi
out_dir="out/$gpu_name/$nb_gpus"
echo "Output dir is : $out_dir"
for g in ${grid[@]}; do
    for p in ${pdim[@]}; do
        # halo is 1/4 of the grid size
        halo_size=$((g / 4))
        run_python scripts/fastpm_jaxdecomp.py -m $g -b $g -p $p -hs $halo_size -ode diffrax -o $out_dir
    done
done