From 02168370338d049164f01d6d97d6eacf921f42a9 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 18 Jul 2024 17:04:52 +0200 Subject: [PATCH] update benchmark and add slurm --- scripts/bench_pm.py | 83 ++++++++++++------ scripts/particle_mesh.slurm | 167 ++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 24 deletions(-) create mode 100644 scripts/particle_mesh.slurm diff --git a/scripts/bench_pm.py b/scripts/bench_pm.py index a572dad..e73a895 100644 --- a/scripts/bench_pm.py +++ b/scripts/bench_pm.py @@ -15,8 +15,8 @@ import jax.numpy as jnp import jax_cosmo as jc import numpy as np from cupy.cuda.nvtx import RangePop, RangePush -from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, - diffeqsolve) +from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, + PIDController, SaveAt, Tsit5, diffeqsolve) from jax.experimental import mesh_utils from jax.experimental.multihost_utils import sync_global_devices from jax.sharding import Mesh, NamedSharding @@ -29,7 +29,8 @@ from jaxpm.pm import linear_field, lpt, make_ode_fn def chrono_fun(fun, *args): start = time.perf_counter() - out = fun(*args).block_until_ready() + out = fun(*args) + out[0].block_until_ready() end = time.perf_counter() return out, end - start @@ -59,18 +60,22 @@ def run_simulation(mesh_shape, cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) - # Evolve the simulation forward - ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) - term = ODETerm( - lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) - if solver_choice == "Dopri5": solver = Dopri5() elif solver_choice == "LeapfrogMidpoint": solver = LeapfrogMidpoint() + elif solver_choice == "Tsit5": + solver = Tsit5() + elif solver_choice == "lpt": + lpt_field = cic_paint_dx(dx, halo_size=halo_size) + return lpt_field, {"num_steps": 0, "Solver": "LPT"} else: raise ValueError( "Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.") + # Evolve the simulation forward + ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) + term = ODETerm( + lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) res = diffeqsolve(term, @@ -93,17 +98,17 @@ def run_simulation(mesh_shape, # Warm start times = [] RangePush("warmup") - final_field, stats, warmup_time = chrono_fun(simulate, 0.32, 0.8) + (final_field, stats), warmup_time = chrono_fun(simulate, 0.32, 0.8) RangePop() sync_global_devices("warmup") for i in range(iterations): RangePush(f"sim iter {i}") - final_field, stats, sim_time = chrono_fun(simulate, 0.32, 0.8) + (final_field, stats), sim_time = chrono_fun(simulate, 0.32, 0.8) RangePop() times.append(sim_time) return stats, warmup_time, times, final_field - if jax.device_count() > 1 and pdims: + if jax.device_count() > 1: devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices.T, axis_names=('x', 'y')) with mesh: @@ -134,7 +139,7 @@ if __name__ == "__main__": type=str, help='Processor dimensions', default=None) - parser.add_argument('-h', + parser.add_argument('-hs', '--halo_size', type=int, help='Halo size', @@ -143,7 +148,11 @@ if __name__ == "__main__": '--solver', type=str, help='Solver', - choices=["Dopri5", "LeapfrogMidpoint"], + choices=[ + "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5", + "LeapfrogMidpoint", "leapfrogmidpoint", "lfm", + "lpt" + ], required=True) parser.add_argument('-i', '--iterations', @@ -170,10 +179,24 @@ if __name__ == "__main__": output_path = args.output_path os.makedirs(output_path, exist_ok=True) + match solver_choice: + case "Dopri5", "dopri5", "d5": + solver_choice = "Dopri5" + case "Tsit5", "tsit5", "t5": + solver_choice = "Tsit5" + case "LeapfrogMidpoint", "leapfrogmidpoint", "lfm": + solver_choice = "LeapfrogMidpoint" + case "lpt": + solver_choice = "lpt" + case _: + raise ValueError( + "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt" + ) + if args.pdims: pdims = tuple(map(int, args.pdims.split("x"))) else: - pdims = None + pdims = (1, 1) mesh_shape = [mesh_size] * 3 @@ -184,14 +207,6 @@ if __name__ == "__main__": iterations, pdims=pdims) - # Save the final field - if args.save_fields: - nb_gpus = jax.device_count() - field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{box_size[0]}/{solver_choice}/{halo_size}" - os.makedirs(field_folder, exist_ok=True) - np.save(f'{field_folder}/final_field_{rank}.npy', - final_field.addressable_data(0)) - # Write benchmark results to CSV # RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD times = np.array(times) @@ -200,11 +215,31 @@ if __name__ == "__main__": max_time = np.max(times) * 1000 mean_time = np.mean(times) * 1000 std_time = np.std(times) * 1000 + with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f: f.write( - f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n" + f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{stats['num_steps']},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n" ) + # Save the final field + nb_gpus = jax.device_count() + pdm_str = f"{pdims[0]}x{pdims[1]}" + field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/{halo_size}" + os.makedirs(field_folder, exist_ok=True) + with open(f'{field_folder}/jaxpm.log', 'w') as f: + f.write(f"Args: {args}\n") + f.write(f"JIT time: {jit_in_ms:.4f} ms\n") + f.write(f"Min time: {min_time:.4f} ms\n") + f.write(f"Max time: {max_time:.4f} ms\n") + f.write(f"Mean time: {mean_time:.4f} ms\n") + f.write(f"Std time: {std_time:.4f} ms\n") + f.write(f"Stats: {stats}\n") + if args.save_fields: + np.save(f'{field_folder}/final_field_0_{rank}.npy', + final_field.addressable_data(0)) + print(f"Finished! Warmup time: {warmup_time:.4f} seconds") print(f"mean times: {np.mean(times):.4f}") - print(f"Stats") + print(f"Stats {stats}") + print(f"Saving to {output_path}/jax_pm_benchmark.csv") + print(f"Saving field and logs in {field_folder}") diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm new file mode 100644 index 0000000..51332a5 --- /dev/null +++ b/scripts/particle_mesh.slurm @@ -0,0 +1,167 @@ +#!/bin/bash +########################################## +## SELECT EITHER tkc@a100 OR tkc@v100 ## +########################################## +#SBATCH --account tkc@a100 +########################################## +#SBATCH --job-name=Particle-Mesh # nom du job +# Il est possible d'utiliser une autre partition que celle par default +# en activant l'une des 5 directives suivantes : +########################################## +## SELECT EITHER a100 or v100-32g ## +########################################## +#SBATCH -C a100 +########################################## +#****************************************** +########################################## +## SELECT Number of nodes and GPUs per node +## For A100 ntasks-per-node and gres=gpu should be 8 +## For V100 ntasks-per-node and gres=gpu should be 4 +########################################## +#SBATCH --nodes=1 # nombre de noeud +#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) +#SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5) +########################################## +## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant +## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant +## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: +########################################## +#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) +########################################## +# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm +#SBATCH --hint=nomultithread # hyperthreading desactive +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie +#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --qos=qos_gpu-dev +#SBATCH --exclusive # ressources dediees + +# Nettoyage des modules charges en interactif et herites par defaut +num_nodes=$SLURM_JOB_NUM_NODES +num_gpu_per_node=$SLURM_NTASKS_PER_NODE +OUTPUT_FOLDER_ARGS=1 +# Calculate the number of GPUs +nb_gpus=$(( num_nodes * num_gpu_per_node)) + +module purge + +# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" +# pour avoir acces aux modules compatibles avec cette partition + +if [ $num_gpu_per_node -eq 8 ]; then + module load cpuarch/amd + source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate +else + source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate +fi + +# Chargement des modules +module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake +module load nvidia-nsight-systems/2024.1.1.59 + +echo "The number of nodes allocated for this job is: $num_nodes" +echo "The number of GPUs allocated for this job is: $nb_gpus" + +export EQX_ON_ERROR=nan +export CUDA_ALLOC=1 + +function profile_python() { + if [ $# -lt 1 ]; then + echo "Usage: profile_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="prof_traces/$script_name" + local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="prof_traces/$script_name/$args" + report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" + fi + + mkdir -p "$output_dir" + mkdir -p "$report_dir" + + srun nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + +function run_python() { + if [ $# -lt 1 ]; then + echo "Usage: run_python [arguments for the script]" + return 1 + fi + + local script_name=$(basename "$1" .py) + local output_dir="traces/$script_name" + + if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then + local args=$(echo "${@:2}" | tr ' ' '_') + # Remove characters '/' and '-' from folder name + args=$(echo "$args" | tr -d '/-') + output_dir="traces/$script_name/$args" + fi + + mkdir -p "$output_dir" + + srun python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true +} + + +# Echo des commandes lancees +set -x + +# Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles +# Execution du code avec binding via bind_gpu.sh : 1 GPU par tache + + + + +declare -A pdims_table +# Define the table +pdims_table[4]="2x2 1x4" +pdims_table[8]="2x4 1x8" +pdims_table[16]="2x8 1x16" +pdims_table[32]="4x8 1x32" +pdims_table[64]="4x16 1x64" +pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2" +pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" + + +#mpch=(128 256 512 1024 2048 4096) +grid=(1024 2048 4096) + +pdim="${pdims_table[$nb_gpus]}" +echo "pdims: $pdim" + +# Check if pdims is not empty +if [ -z "$pdim" ]; then + echo "pdims is empty" + echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" + echo "Number of nodes selected: $num_nodes" + echo "Number of gpus per node: $num_gpu_per_node" + exit 1 +fi + +# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 + +if [ $num_gpu_per_node -eq 8 ]; then + gpu_name="a100" +else + gpu_name="v100" +fi + +out_dir="out/$gpu_name/$nb_gpus" + +echo "Output dir is : $out_dir" + +for g in ${grid[@]}; do + for p in ${pdim[@]}; do + # halo is 1/4 of the grid size + halo_size=$((g / 4)) + slaunch scripts/fastpm_jaxdecomp.py -m $g -b $g -p $p -hs $halo_size -ode diffrax -o $out_dir + done +done