From 2ea05a1cd6a451dec406b222b69928c5cf6b6956 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 7 Aug 2024 23:52:13 +0200 Subject: [PATCH] Update for jaxDecomp pure JAX --- benchmarks/bench_pm.py | 187 ++++++++++++------ ...le_mesh_a100.slurm => particle_mesh.slurm} | 72 +++---- benchmarks/particle_mesh_v100.slurm | 184 ----------------- benchmarks/pmwd_a100.slurm | 165 ---------------- benchmarks/{pmwd_v100.slurm => pmwd_pm.slurm} | 73 +++---- benchmarks/run_all_jobs.sh | 18 +- jaxpm/distributed.py | 4 +- jaxpm/painting.py | 38 ++-- scripts/particle_mesh.slurm | 5 +- 9 files changed, 214 insertions(+), 532 deletions(-) rename benchmarks/{particle_mesh_a100.slurm => particle_mesh.slurm} (65%) delete mode 100644 benchmarks/particle_mesh_v100.slurm delete mode 100644 benchmarks/pmwd_a100.slurm rename benchmarks/{pmwd_v100.slurm => pmwd_pm.slurm} (64%) diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py index 9b916f9..c80fdd5 100644 --- a/benchmarks/bench_pm.py +++ b/benchmarks/bench_pm.py @@ -25,7 +25,7 @@ from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn - +from jax import make_jaxpr def run_simulation(mesh_shape, @@ -33,7 +33,10 @@ def run_simulation(mesh_shape, halo_size, solver_choice, iterations, - pdims=None): + hlo_print, + trace, + pdims=None, + output_path="."): @jax.jit def simulate(omega_c, sigma8): @@ -60,7 +63,6 @@ def run_simulation(mesh_shape, solver = Tsit5() elif solver_choice == "lpt": lpt_field = cic_paint_dx(dx, halo_size=halo_size) - print(f"TYPE of lpt_field: {type(lpt_field)}") return lpt_field, {"num_steps": 0} else: raise ValueError( @@ -92,16 +94,37 @@ def run_simulation(mesh_shape, def run(): # Warm start - chrono_fun = Timer() - RangePush("warmup") - final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) - RangePop() - sync_global_devices("warmup") - for i in range(iterations): - RangePush(f"sim iter {i}") - final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) - RangePop() - return final_field, stats, chrono_fun + if hlo_print: + jaxpr = make_jaxpr(simulate)(0.32, 0.8) + lowered = jax.jit(simulate).lower(0.32, 0.8) + lower_as_text = lowered.as_text() + compiled = lowered.compile() + compiled_again = jax.jit(simulate).lower(0.32, 0.8).compile() + return jaxpr , compiled , compiled_again + elif trace: + jit_output = f"{output_path}/jit_trace" + first_run_output = f"{output_path}/first_run_trace" + second_run_output = f"{output_path}/second_run_trace" + with jax.profiler.trace(jit_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + with jax.profiler.trace(first_run_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + with jax.profiler.trace(second_run_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + else: + chrono_fun = Timer() + RangePush("warmup") + final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) + RangePop() + sync_global_devices("warmup") + for i in range(iterations): + RangePush(f"sim iter {i}") + final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) + RangePop() + return final_field, stats, chrono_fun if jax.device_count() > 1: devices = mesh_utils.create_device_mesh(pdims) @@ -143,7 +166,7 @@ if __name__ == "__main__": '--halo_size', type=int, help='Halo size', - required=True) + default=None) parser.add_argument('-s', '--solver', type=str, @@ -153,12 +176,7 @@ if __name__ == "__main__": "LeapfrogMidpoint", "leapfrogmidpoint", "lfm", "lpt" ], - required=True) - parser.add_argument('-i', - '--iterations', - type=int, - help='Number of iterations', - default=10) + default="lpt") parser.add_argument('-o', '--output_path', type=str, @@ -173,16 +191,33 @@ if __name__ == "__main__": type=int, help='Number of nodes', default=1) + parser.add_argument('-i', + '--iterations', + type=int, + help='Number of iterations', + default=10) + group = parser.add_mutually_exclusive_group() + group.add_argument('-hlo', + '--hlo_print', + action='store_true', + help='Print hlo generated by XLA') + group.add_argument('-t', + '--trace', + action='store_true', + help='Profile using tensorboard') args = parser.parse_args() mesh_size = args.mesh_size box_size = [args.box_size] * 3 - halo_size = args.halo_size + halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size solver_choice = args.solver iterations = args.iterations output_path = args.output_path os.makedirs(output_path, exist_ok=True) - + hlo_print = args.hlo_print + trace = args.trace + nb_gpus = jax.device_count() + print(f"solver choice: {solver_choice}") match solver_choice: case "Dopri5" | "dopri5"| "d5": @@ -205,45 +240,81 @@ if __name__ == "__main__": if args.pdims: pdims = tuple(map(int, args.pdims.split("x"))) else: - pdims = (1, 1) + pdims = (1, jax.device_count()) + pdm_str = f"{pdims[0]}x{pdims[1]}" mesh_shape = [mesh_size] * 3 - - final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims) - print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") + if trace: + trace_folder = f"{output_path}/profiling/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(trace_folder, exist_ok=True) + run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, trace_folder) + print(f"Profiling done! Check {trace_folder}") + elif hlo_print: + hlo_folder = f"{output_path}/hlo/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(hlo_folder, exist_ok=True) + jaxpr , compiled , compiled2 = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, hlo_folder) + print(f"type of memory analysis {type(compiled.memory_analysis())}") + print(f"memory analysis {compiled.memory_analysis()}") + print(f"memory analysis again {compiled2.memory_analysis()}") + jax.tree.map(lambda x: print(x), compiled.memory_analysis()) + with open(f'{hlo_folder}/hlo_jaxpm.md', 'w') as f: + f.write(f"# JAXPM HLO\n") + f.write(f"## Args: {args}\n") + f.write(f"## JAXPR is \n") + f.write(f'---\n') + f.write(f"{jaxpr}\n") + f.write(f'---\n') + f.write(f"Lowered as text is \n") + f.write(f'---\n') + # f.write(f"{lower_as_text}\n") + f.write(f'---\n') + f.write(f"Compiled is \n") + f.write(f'---\n') + f.write(f"{compiled.as_text()}\n") + f.write(f"Cost analysis is \n") + f.write(f'---\n') + f.write(f"{compiled.cost_analysis()[0]['flops']}\n") + f.write(f'---\n') + f.write(f"Memory analysis is \n") + f.write(f'---\n') + f.write(f"{compiled.memory_analysis()}\n") + f.write(f'---\n') - metadata = { - 'rank': rank, - 'function_name': f'JAXPM-{solver_choice}', - 'precision': args.precision, - 'x': str(mesh_size), - 'y': str(mesh_size), - 'z': str(stats["num_steps"]), - 'px': str(pdims[0]), - 'py': str(pdims[1]), - 'backend': 'NCCL', - 'nodes': str(args.nodes) - } - # Print the results to a CSV file - chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) + print(f"Saved HLO to {hlo_folder}") + else: + final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, output_path) + print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") + metadata = { + 'rank': rank, + 'function_name': f'JAXPM-{solver_choice}', + 'precision': args.precision, + 'x': str(mesh_size), + 'y': str(mesh_size), + 'z': str(stats["num_steps"]), + 'px': str(pdims[0]), + 'py': str(pdims[1]), + 'backend': 'NCCL', + 'nodes': str(args.nodes) + } + # Print the results to a CSV file + chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) - # Save the final field - nb_gpus = jax.device_count() - pdm_str = f"{pdims[0]}x{pdims[1]}" - field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(field_folder, exist_ok=True) - with open(f'{field_folder}/jaxpm.log', 'w') as f: - f.write(f"Args: {args}\n") - f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i , time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") - f.write(f"Stats: {stats}\n") - if args.save_fields: - np.save(f'{field_folder}/final_field_0_{rank}.npy', - final_field.addressable_data(0)) + # Save the final field - print(f"Finished! ") - print(f"Stats {stats}") - print(f"Saving to {output_path}/jax_pm_benchmark.csv") - print(f"Saving field and logs in {field_folder}") + field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(field_folder, exist_ok=True) + with open(f'{field_folder}/jaxpm.log', 'w') as f: + f.write(f"Args: {args}\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i , time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") + f.write(f"Stats: {stats}\n") + if args.save_fields: + np.save(f'{field_folder}/final_field_0_{rank}.npy', + final_field.addressable_data(0)) + + print(f"Finished! ") + print(f"Stats {stats}") + print(f"Saving to {output_path}/jax_pm_benchmark.csv") + print(f"Saving field and logs in {field_folder}") diff --git a/benchmarks/particle_mesh_a100.slurm b/benchmarks/particle_mesh.slurm similarity index 65% rename from benchmarks/particle_mesh_a100.slurm rename to benchmarks/particle_mesh.slurm index 65930b2..ab55b05 100644 --- a/benchmarks/particle_mesh_a100.slurm +++ b/benchmarks/particle_mesh.slurm @@ -1,40 +1,15 @@ #!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@a100 -########################################## -#SBATCH --job-name=1N-FFT-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C a100 -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:8 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## +############################################################################################################################## +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +############################################################################################################################## +#SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm #SBATCH --hint=nomultithread # hyperthreading desactive #SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --error=%x_%N_a100.err # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --exclusive # ressources dediees ##SBATCH --qos=qos_gpu-dev -## SBATCH --exclusive # ressources dediees # Nettoyage des modules charges en interactif et herites par defaut num_nodes=$SLURM_JOB_NUM_NODES num_gpu_per_node=$SLURM_NTASKS_PER_NODE @@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node)) module purge -echo "Job constraint: $SLURM_JOB_CONSTRAINT" echo "Job partition: $SLURM_JOB_PARTITION" # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" # pour avoir acces aux modules compatibles avec cette partition -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then +if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then module load cpuarch/amd source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate gpu_name=a100 @@ -66,8 +40,10 @@ module load nvidia-nsight-systems/2024.1.1.59 echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" +export EQX_ON_ERROR=nan export ENABLE_PERFO_STEP=NVTX export MPI4JAX_USE_CUDA_MPI=1 + function profile_python() { if [ $# -lt 1 ]; then echo "Usage: profile_python [arguments for the script]" @@ -75,14 +51,14 @@ function profile_python() { fi local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" + local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" + output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" fi @@ -99,13 +75,13 @@ function run_python() { fi local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" + local output_dir="traces/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" + output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" fi mkdir -p "$output_dir" @@ -142,12 +118,11 @@ pdims_table[8]="2x4 1x8 8x1 4x2" pdims_table[16]="4x4 1x16 16x1" pdims_table[32]="4x8 8x4 1x32 32x1" pdims_table[64]="8x8 16x4 1x64 64x1" -pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2" -pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" - +pdims_table[128]="8x16 16x8 1x128 128x1" +pdims_table[256]="16x16 1x256 256x1" # mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096 8192) precisions=(float32 float64) pdim="${pdims_table[$nb_gpus]}" solvers=(lpt lfm) @@ -164,8 +139,9 @@ fi # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" +trace_dir="traces/$gpu_name/$nb_gpus/bench_pm" +echo "Output dir is : $out_dir" +echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do @@ -175,9 +151,13 @@ for pr in "${precisions[@]}"; do slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes done done + # delete crash core dump files + rm -f core.python.* done done - - - +# # zip the output files and traces +# tar -czvf $out_dir.tar.gz $out_dir +# tar -czvf $trace_dir.tar.gz $trace_dir +# # remove the output files and traces +# rm -rf $out_dir $trace_dir diff --git a/benchmarks/particle_mesh_v100.slurm b/benchmarks/particle_mesh_v100.slurm deleted file mode 100644 index 9eeb610..0000000 --- a/benchmarks/particle_mesh_v100.slurm +++ /dev/null @@ -1,184 +0,0 @@ -#!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@v100 -########################################## -#SBATCH --job-name=V100Particle-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C v100-32g -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=4 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:4 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## -#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm -#SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS) -#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) -#SBATCH --qos=qos_gpu-dev -#SBATCH --exclusive # ressources dediees -# Nettoyage des modules charges en interactif et herites par defaut -num_nodes=$SLURM_JOB_NUM_NODES -num_gpu_per_node=$SLURM_NTASKS_PER_NODE -OUTPUT_FOLDER_ARGS=1 -# Calculate the number of GPUs -nb_gpus=$(( num_nodes * num_gpu_per_node)) - -module purge - -echo "Job constraint: $SLURM_JOB_CONSTRAINT" -echo "Job partition: $SLURM_JOB_PARTITION" -# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" -# pour avoir acces aux modules compatibles avec cette partition - -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then - module load cpuarch/amd - source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate - gpu_name=a100 -else - source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate - gpu_name=v100 -fi - -# Chargement des modules -module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -module load nvidia-nsight-systems/2024.1.1.59 - - -echo "The number of nodes allocated for this job is: $num_nodes" -echo "The number of GPUs allocated for this job is: $nb_gpus" - -export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 -export ENABLE_PERFO_STEP=NVTX -export MPI4JAX_USE_CUDA_MPI=1 -function profile_python() { - if [ $# -lt 1 ]; then - echo "Usage: profile_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" - local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" - report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - mkdir -p "$report_dir" - - srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - -function run_python() { - if [ $# -lt 1 ]; then - echo "Usage: run_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" - fi - - mkdir -p "$output_dir" - - srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - - -# run or profile - -function slaunch() { - run_python "$@" -} - -function plaunch() { - profile_python "$@" -} - -# Echo des commandes lancees -set -x - -# Pour ne pas utiliser le /tmp -export TMPDIR=$JOBSCRATCH -# Pour contourner un bogue dans les versions actuelles de Nsight Systems -# il est également nécessaire de créer un lien symbolique permettant de -# faire pointer le répertoire /tmp/nvidia vers TMPDIR -ln -s $JOBSCRATCH /tmp/nvidia - -declare -A pdims_table -# Define the table -pdims_table[1]="1x1" -pdims_table[4]="2x2 1x4 4x1" -pdims_table[8]="2x4 1x8 8x1 4x2" -pdims_table[16]="4x4 1x16 16x1" -pdims_table[32]="4x8 8x4 1x32 32x1" -pdims_table[64]="8x8 16x4 1x64 64x1" -pdims_table[128]="8x16 16x8 4x32 1x128 128x1" -pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" - - -# mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) -precisions=(float32 float64) -pdim="${pdims_table[$nb_gpus]}" -solvers=(lpt lfm) -echo "pdims: $pdim" - -# Check if pdims is not empty -if [ -z "$pdim" ]; then - echo "pdims is empty" - echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" - echo "Number of nodes selected: $num_nodes" - echo "Number of gpus per node: $num_gpu_per_node" - exit 1 -fi - -# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 -out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" - -for pr in "${precisions[@]}"; do - for g in "${grid[@]}"; do - for solver in "${solvers[@]}"; do - for p in $pdim; do - halo_size=$((g / 4)) - slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes - done - done - done -done - - - diff --git a/benchmarks/pmwd_a100.slurm b/benchmarks/pmwd_a100.slurm deleted file mode 100644 index 99c64c7..0000000 --- a/benchmarks/pmwd_a100.slurm +++ /dev/null @@ -1,165 +0,0 @@ -#!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@a100 -########################################## -#SBATCH --job-name=1N-FFT-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C a100 -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:1 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## -#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm -#SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) -#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) -##SBATCH --qos=qos_gpu-dev -## SBATCH --exclusive # ressources dediees -# Nettoyage des modules charges en interactif et herites par defaut -num_nodes=$SLURM_JOB_NUM_NODES -num_gpu_per_node=$SLURM_NTASKS_PER_NODE -OUTPUT_FOLDER_ARGS=1 -# Calculate the number of GPUs -nb_gpus=$(( num_nodes * num_gpu_per_node)) - -module purge - -echo "Job constraint: $SLURM_JOB_CONSTRAINT" -echo "Job partition: $SLURM_JOB_PARTITION" -# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" -# pour avoir acces aux modules compatibles avec cette partition - -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then - module load cpuarch/amd - source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate - gpu_name=a100 -else - source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate - gpu_name=v100 -fi - -# Chargement des modules -module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -module load nvidia-nsight-systems/2024.1.1.59 - - -echo "The number of nodes allocated for this job is: $num_nodes" -echo "The number of GPUs allocated for this job is: $nb_gpus" - -export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 -export ENABLE_PERFO_STEP=NVTX -export MPI4JAX_USE_CUDA_MPI=1 -function profile_python() { - if [ $# -lt 1 ]; then - echo "Usage: profile_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" - local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" - report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - mkdir -p "$report_dir" - - srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - -function run_python() { - if [ $# -lt 1 ]; then - echo "Usage: run_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" - fi - - mkdir -p "$output_dir" - - srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - - -# run or profile - -function slaunch() { - run_python "$@" -} - -function plaunch() { - profile_python "$@" -} - -# Echo des commandes lancees -set -x - -# Pour ne pas utiliser le /tmp -export TMPDIR=$JOBSCRATCH -# Pour contourner un bogue dans les versions actuelles de Nsight Systems -# il est également nécessaire de créer un lien symbolique permettant de -# faire pointer le répertoire /tmp/nvidia vers TMPDIR -ln -s $JOBSCRATCH /tmp/nvidia - -# mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) -precisions=(float32 float64) -solvers=(lpt lfm) - -# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 - -if [ $num_gpu_per_node -eq 8 ]; then - gpu_name="a100" -else - gpu_name="v100" -fi - -out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" - -for pr in "${precisions[@]}"; do - for g in "${grid[@]}"; do - for solver in "${solvers[@]}"; do - launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f - done - done -done - - - diff --git a/benchmarks/pmwd_v100.slurm b/benchmarks/pmwd_pm.slurm similarity index 64% rename from benchmarks/pmwd_v100.slurm rename to benchmarks/pmwd_pm.slurm index 4c58db5..3e2cee8 100644 --- a/benchmarks/pmwd_v100.slurm +++ b/benchmarks/pmwd_pm.slurm @@ -1,40 +1,15 @@ #!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@v100 -########################################## -#SBATCH --job-name=16N-V100Particle-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C v100-32g -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:1 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## +############################################################################################################################## +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +############################################################################################################################## +#SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm #SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie #SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) -#SBATCH --qos=qos_gpu-dev #SBATCH --exclusive # ressources dediees +##SBATCH --qos=qos_gpu-dev # Nettoyage des modules charges en interactif et herites par defaut num_nodes=$SLURM_JOB_NUM_NODES num_gpu_per_node=$SLURM_NTASKS_PER_NODE @@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node)) module purge -echo "Job constraint: $SLURM_JOB_CONSTRAINT" echo "Job partition: $SLURM_JOB_PARTITION" # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" # pour avoir acces aux modules compatibles avec cette partition -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then +if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then module load cpuarch/amd source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate gpu_name=a100 @@ -67,9 +41,9 @@ echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 export ENABLE_PERFO_STEP=NVTX export MPI4JAX_USE_CUDA_MPI=1 + function profile_python() { if [ $# -lt 1 ]; then echo "Usage: profile_python [arguments for the script]" @@ -77,14 +51,14 @@ function profile_python() { fi local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" + local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" + output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" fi @@ -101,13 +75,13 @@ function run_python() { fi local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" + local output_dir="traces/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" + output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" fi mkdir -p "$output_dir" @@ -116,6 +90,7 @@ function run_python() { } + # run or profile function slaunch() { @@ -136,13 +111,9 @@ export TMPDIR=$JOBSCRATCH # faire pointer le répertoire /tmp/nvidia vers TMPDIR ln -s $JOBSCRATCH /tmp/nvidia - - - # mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096 8192) precisions=(float32 float64) -pdim="${pdims_table[$nb_gpus]}" solvers=(lpt lfm) # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 @@ -154,17 +125,23 @@ else fi out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" - +trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd" +echo "Output dir is : $out_dir" +echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do for solver in "${solvers[@]}"; do slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f done + # delete crash core dump files + rm -f core.python.* done done - - +# # zip the output files and traces +# tar -czvf $out_dir.tar.gz $out_dir +# tar -czvf $trace_dir.tar.gz $trace_dir +# # remove the output files and traces +# rm -rf $out_dir $trace_dir +# diff --git a/benchmarks/run_all_jobs.sh b/benchmarks/run_all_jobs.sh index cfee491..2e6cb3f 100755 --- a/benchmarks/run_all_jobs.sh +++ b/benchmarks/run_all_jobs.sh @@ -1,19 +1,21 @@ #!/bin/bash # Run all slurms jobs -nodes_v100=(1 2 4 8 16) -nodes_a100=(1 2 4 8 16) +nodes_v100=(1 2 4 8 16 32) +nodes_a100=(1 2 4 8 16 32) for n in ${nodes_v100[@]}; do - sbatch --nodes=$n --job-name=v100_$n-JAXPM particle_mesh_v100.slurm + sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm done for n in ${nodes_a100[@]}; do - sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm + sbatch --account=tkc@a100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C a100 --job-name=JAXPM-$n-N-a100 particle_mesh.slurm done # single GPUs -sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_v100.slurm -sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_a100.slurm -sbatch --job-name=PMWD-v100 pmwd_v100.slurm -sbatch --job-name=PMWD-a100 pmwd_a100.slurm +sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=JAXPM-1GPU-V100 particle_mesh.slurm +sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-A100 particle_mesh.slurm +sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=PMWD-1GPU-v100 pmwd_pm.slurm +sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-a100 pmwd_pm.slurm + + diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 9fb0e15..c16bc33 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -88,8 +88,8 @@ def get_halo_size(halo_size): def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): mesh = mesh_lib.thread_resources.env.physical_mesh - if distributed and not (mesh.empty) and (halo_extents[0] > 0 - or halo_extents[1] > 0): + if distributed and not (mesh.empty) and (halo_extents > 0 + or halo_extents > 0): return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) else: return x diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 7160913..7d9e9fa 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None): @partial(jax.jit, static_argnums=(2, )) def cic_paint(mesh, positions, halo_size=0, weight=None): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = autoshmap(cic_paint_impl, in_specs=(P('x', 'y'), P('x', 'y'), P()), out_specs=P('x', 'y'))(mesh, positions, weight) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) - mesh = slice_unpad(mesh, halo_size) + halo_extents=halo_size // 2, + halo_periods=True) + mesh = slice_unpad(mesh, halo_padding) return mesh @@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement): @partial(jax.jit, static_argnums=(2, )) def cic_read(mesh, displacement, halo_size=0): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) + halo_extents=halo_size//2, + halo_periods=True) displacement = autoshmap(cic_read_impl, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))(mesh, displacement) @@ -160,16 +160,16 @@ def cic_paint_dx_impl(displacements, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_paint_dx(displacements, halo_size=0): - halo_size, halo_extents = get_halo_size(halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) - mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), + mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(displacements) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) - mesh = slice_unpad(mesh, halo_size) + halo_extents=halo_size//2, + halo_periods=True) + mesh = slice_unpad(mesh, halo_padding) return mesh @@ -194,12 +194,12 @@ def cic_read_dx_impl(mesh , halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_read_dx(mesh, halo_size=0): # return mesh - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) - displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size), + halo_extents=halo_size//2, + halo_periods=True) + displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(mesh) diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm index 51332a5..e8924d1 100644 --- a/scripts/particle_mesh.slurm +++ b/scripts/particle_mesh.slurm @@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59 echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" -export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 function profile_python() { if [ $# -lt 1 ]; then @@ -122,6 +122,7 @@ set -x declare -A pdims_table # Define the table +pdims_table[1]="1x1" pdims_table[4]="2x2 1x4" pdims_table[8]="2x4 1x8" pdims_table[16]="2x8 1x16"