diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py index 2bf4534..5e0c3a9 100644 --- a/benchmarks/bench_pm.py +++ b/benchmarks/bench_pm.py @@ -26,14 +26,17 @@ from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn - +from jax import make_jaxpr def run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, - pdims=None): + hlo_print, + trace, + pdims=None, + output_path="."): @jax.jit def simulate(omega_c, sigma8): @@ -60,7 +63,6 @@ def run_simulation(mesh_shape, solver = Tsit5() elif solver_choice == "lpt": lpt_field = cic_paint_dx(dx, halo_size=halo_size) - print(f"TYPE of lpt_field: {type(lpt_field)}") return lpt_field, {"num_steps": 0} else: raise ValueError( @@ -92,6 +94,7 @@ def run_simulation(mesh_shape, def run(): # Warm start +<<<<<<< HEAD chrono_fun = Timer() RangePush("warmup") final_field, stats = chrono_fun.chrono_jit(simulate, @@ -108,6 +111,39 @@ def run_simulation(mesh_shape, ndarray_arg=0) RangePop() return final_field, stats, chrono_fun +======= + if hlo_print: + jaxpr = make_jaxpr(simulate)(0.32, 0.8) + lowered = jax.jit(simulate).lower(0.32, 0.8) + lower_as_text = lowered.as_text() + compiled = lowered.compile() + compiled_again = jax.jit(simulate).lower(0.32, 0.8).compile() + return jaxpr , compiled , compiled_again + elif trace: + jit_output = f"{output_path}/jit_trace" + first_run_output = f"{output_path}/first_run_trace" + second_run_output = f"{output_path}/second_run_trace" + with jax.profiler.trace(jit_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + with jax.profiler.trace(first_run_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + with jax.profiler.trace(second_run_output , create_perfetto_trace=True): + final_field, stats = simulate(0.32, 0.8) + final_field.block_until_ready() + else: + chrono_fun = Timer() + RangePush("warmup") + final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0) + RangePop() + sync_global_devices("warmup") + for i in range(iterations): + RangePush(f"sim iter {i}") + final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0) + RangePop() + return final_field, stats, chrono_fun +>>>>>>> glab/ASKabalan/jaxdecomp_proto if jax.device_count() > 1: devices = mesh_utils.create_device_mesh(pdims) @@ -151,7 +187,7 @@ if __name__ == "__main__": '--halo_size', type=int, help='Halo size', - required=True) + default=None) parser.add_argument('-s', '--solver', type=str, @@ -161,12 +197,7 @@ if __name__ == "__main__": "LeapfrogMidpoint", "leapfrogmidpoint", "lfm", "lpt" ], - required=True) - parser.add_argument('-i', - '--iterations', - type=int, - help='Number of iterations', - default=10) + default="lpt") parser.add_argument('-o', '--output_path', type=str, @@ -181,15 +212,39 @@ if __name__ == "__main__": type=int, help='Number of nodes', default=1) +<<<<<<< HEAD +======= + parser.add_argument('-i', + '--iterations', + type=int, + help='Number of iterations', + default=10) + group = parser.add_mutually_exclusive_group() + group.add_argument('-hlo', + '--hlo_print', + action='store_true', + help='Print hlo generated by XLA') + group.add_argument('-t', + '--trace', + action='store_true', + help='Profile using tensorboard') + +>>>>>>> glab/ASKabalan/jaxdecomp_proto args = parser.parse_args() mesh_size = args.mesh_size box_size = [args.box_size] * 3 - halo_size = args.halo_size + halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size solver_choice = args.solver iterations = args.iterations output_path = args.output_path os.makedirs(output_path, exist_ok=True) +<<<<<<< HEAD +======= + hlo_print = args.hlo_print + trace = args.trace + nb_gpus = jax.device_count() +>>>>>>> glab/ASKabalan/jaxdecomp_proto print(f"solver choice: {solver_choice}") match solver_choice: @@ -213,9 +268,11 @@ if __name__ == "__main__": if args.pdims: pdims = tuple(map(int, args.pdims.split("x"))) else: - pdims = (1, 1) + pdims = (1, jax.device_count()) + pdm_str = f"{pdims[0]}x{pdims[1]}" mesh_shape = [mesh_size] * 3 +<<<<<<< HEAD final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, @@ -226,6 +283,50 @@ if __name__ == "__main__": ) metadata = { +======= + + if trace: + trace_folder = f"{output_path}/profiling/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(trace_folder, exist_ok=True) + run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, trace_folder) + print(f"Profiling done! Check {trace_folder}") + elif hlo_print: + hlo_folder = f"{output_path}/hlo/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(hlo_folder, exist_ok=True) + jaxpr , compiled , compiled2 = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, hlo_folder) + print(f"type of memory analysis {type(compiled.memory_analysis())}") + print(f"memory analysis {compiled.memory_analysis()}") + print(f"memory analysis again {compiled2.memory_analysis()}") + jax.tree.map(lambda x: print(x), compiled.memory_analysis()) + with open(f'{hlo_folder}/hlo_jaxpm.md', 'w') as f: + f.write(f"# JAXPM HLO\n") + f.write(f"## Args: {args}\n") + f.write(f"## JAXPR is \n") + f.write(f'---\n') + f.write(f"{jaxpr}\n") + f.write(f'---\n') + f.write(f"Lowered as text is \n") + f.write(f'---\n') + # f.write(f"{lower_as_text}\n") + f.write(f'---\n') + f.write(f"Compiled is \n") + f.write(f'---\n') + f.write(f"{compiled.as_text()}\n") + f.write(f"Cost analysis is \n") + f.write(f'---\n') + f.write(f"{compiled.cost_analysis()[0]['flops']}\n") + f.write(f'---\n') + f.write(f"Memory analysis is \n") + f.write(f'---\n') + f.write(f"{compiled.memory_analysis()}\n") + f.write(f'---\n') + + print(f"Saved HLO to {hlo_folder}") + else: + final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, output_path) + print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}") + metadata = { +>>>>>>> glab/ASKabalan/jaxdecomp_proto 'rank': rank, 'function_name': f'JAXPM-{solver_choice}', 'precision': args.precision, @@ -236,6 +337,7 @@ if __name__ == "__main__": 'py': str(pdims[1]), 'backend': 'NCCL', 'nodes': str(args.nodes) +<<<<<<< HEAD } # Print the results to a CSV file chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) @@ -254,8 +356,27 @@ if __name__ == "__main__": if args.save_fields: np.save(f'{field_folder}/final_field_0_{rank}.npy', final_field.addressable_data(0)) +======= + } + # Print the results to a CSV file + chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) - print(f"Finished! ") - print(f"Stats {stats}") - print(f"Saving to {output_path}/jax_pm_benchmark.csv") - print(f"Saving field and logs in {field_folder}") + # Save the final field +>>>>>>> glab/ASKabalan/jaxdecomp_proto + + field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" + os.makedirs(field_folder, exist_ok=True) + with open(f'{field_folder}/jaxpm.log', 'w') as f: + f.write(f"Args: {args}\n") + f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") + for i , time in enumerate(chrono_fun.times): + f.write(f"Time {i}: {time:.4f} ms\n") + f.write(f"Stats: {stats}\n") + if args.save_fields: + np.save(f'{field_folder}/final_field_0_{rank}.npy', + final_field.addressable_data(0)) + + print(f"Finished! ") + print(f"Stats {stats}") + print(f"Saving to {output_path}/jax_pm_benchmark.csv") + print(f"Saving field and logs in {field_folder}") diff --git a/benchmarks/particle_mesh_a100.slurm b/benchmarks/particle_mesh.slurm similarity index 65% rename from benchmarks/particle_mesh_a100.slurm rename to benchmarks/particle_mesh.slurm index a94019c..330fa1a 100644 --- a/benchmarks/particle_mesh_a100.slurm +++ b/benchmarks/particle_mesh.slurm @@ -1,40 +1,15 @@ #!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@a100 -########################################## -#SBATCH --job-name=1N-FFT-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C a100 -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:8 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## +############################################################################################################################## +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +############################################################################################################################## +#SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm #SBATCH --hint=nomultithread # hyperthreading desactive #SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --error=%x_%N_a100.err # nom du fichier d'erreur (ici commun avec la sortie) +#SBATCH --exclusive # ressources dediees ##SBATCH --qos=qos_gpu-dev -## SBATCH --exclusive # ressources dediees # Nettoyage des modules charges en interactif et herites par defaut num_nodes=$SLURM_JOB_NUM_NODES num_gpu_per_node=$SLURM_NTASKS_PER_NODE @@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node)) module purge -echo "Job constraint: $SLURM_JOB_CONSTRAINT" echo "Job partition: $SLURM_JOB_PARTITION" # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" # pour avoir acces aux modules compatibles avec cette partition -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then +if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then module load cpuarch/amd source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate gpu_name=a100 @@ -66,8 +40,10 @@ module load nvidia-nsight-systems/2024.1.1.59 echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" +export EQX_ON_ERROR=nan export ENABLE_PERFO_STEP=NVTX export MPI4JAX_USE_CUDA_MPI=1 + function profile_python() { if [ $# -lt 1 ]; then echo "Usage: profile_python [arguments for the script]" @@ -75,14 +51,14 @@ function profile_python() { fi local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" + local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" + output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" fi @@ -99,13 +75,13 @@ function run_python() { fi local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" + local output_dir="traces/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" + output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" fi mkdir -p "$output_dir" @@ -142,12 +118,11 @@ pdims_table[8]="2x4 1x8 8x1 4x2" pdims_table[16]="4x4 1x16 16x1" pdims_table[32]="4x8 8x4 1x32 32x1" pdims_table[64]="8x8 16x4 1x64 64x1" -pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2" -pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4" - +pdims_table[128]="8x16 16x8 1x128 128x1" +pdims_table[256]="16x16 1x256 256x1" # mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096 8192) precisions=(float32 float64) pdim="${pdims_table[$nb_gpus]}" solvers=(lpt lfm) @@ -164,8 +139,9 @@ fi # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" +trace_dir="traces/$gpu_name/$nb_gpus/bench_pm" +echo "Output dir is : $out_dir" +echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do @@ -175,5 +151,7 @@ for pr in "${precisions[@]}"; do slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes done done + # delete crash core dump files + rm -f core.python.* done done diff --git a/benchmarks/pmwd_v100.slurm b/benchmarks/pmwd_pm.slurm similarity index 64% rename from benchmarks/pmwd_v100.slurm rename to benchmarks/pmwd_pm.slurm index 9ca5f89..171a7b3 100644 --- a/benchmarks/pmwd_v100.slurm +++ b/benchmarks/pmwd_pm.slurm @@ -1,40 +1,15 @@ #!/bin/bash -########################################## -## SELECT EITHER tkc@a100 OR tkc@v100 ## -########################################## -#SBATCH --account tkc@v100 -########################################## -#SBATCH --job-name=16N-V100Particle-Mesh # nom du job -# Il est possible d'utiliser une autre partition que celle par default -# en activant l'une des 5 directives suivantes : -########################################## -## SELECT EITHER a100 or v100-32g ## -########################################## -#SBATCH -C v100-32g -########################################## -#****************************************** -########################################## -## SELECT Number of nodes and GPUs per node -## For A100 ntasks-per-node and gres=gpu should be 8 -## For V100 ntasks-per-node and gres=gpu should be 4 -########################################## -#SBATCH --nodes=1 # nombre de noeud -#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud) -#SBATCH --gres=gpu:1 # nombre de GPU par nÅ“ud (max 8 avec gpu_p2, gpu_p5) -########################################## -## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant -## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant -## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache: -########################################## +############################################################################################################################## +# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm +############################################################################################################################## +#SBATCH --job-name=Particle-Mesh # nom du job #SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -########################################## -# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm #SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS) +#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie #SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) -#SBATCH --qos=qos_gpu-dev #SBATCH --exclusive # ressources dediees +##SBATCH --qos=qos_gpu-dev # Nettoyage des modules charges en interactif et herites par defaut num_nodes=$SLURM_JOB_NUM_NODES num_gpu_per_node=$SLURM_NTASKS_PER_NODE @@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node)) module purge -echo "Job constraint: $SLURM_JOB_CONSTRAINT" echo "Job partition: $SLURM_JOB_PARTITION" # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" # pour avoir acces aux modules compatibles avec cette partition -if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then +if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then module load cpuarch/amd source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate gpu_name=a100 @@ -67,9 +41,9 @@ echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 export ENABLE_PERFO_STEP=NVTX export MPI4JAX_USE_CUDA_MPI=1 + function profile_python() { if [ $# -lt 1 ]; then echo "Usage: profile_python [arguments for the script]" @@ -77,14 +51,14 @@ function profile_python() { fi local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$script_name" + local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$script_name/$args" + output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" fi @@ -101,13 +75,13 @@ function run_python() { fi local script_name=$(basename "$1" .py) - local output_dir="traces/$script_name" + local output_dir="traces/$gpu_name/$nb_gpus/$script_name" if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then local args=$(echo "${@:2}" | tr ' ' '_') # Remove characters '/' and '-' from folder name args=$(echo "$args" | tr -d '/-') - output_dir="traces/$script_name/$args" + output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" fi mkdir -p "$output_dir" @@ -116,6 +90,7 @@ function run_python() { } + # run or profile function slaunch() { @@ -136,13 +111,9 @@ export TMPDIR=$JOBSCRATCH # faire pointer le répertoire /tmp/nvidia vers TMPDIR ln -s $JOBSCRATCH /tmp/nvidia - - - # mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096) +grid=(256 512 1024 2048 4096 8192) precisions=(float32 float64) -pdim="${pdims_table[$nb_gpus]}" solvers=(lpt lfm) # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 @@ -154,14 +125,23 @@ else fi out_dir="pm_prof/$gpu_name/$nb_gpus" - -echo "Output dir is : $out_dir" - +trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd" +echo "Output dir is : $out_dir" +echo "Trace dir is : $trace_dir" for pr in "${precisions[@]}"; do for g in "${grid[@]}"; do for solver in "${solvers[@]}"; do slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f done + # delete crash core dump files + rm -f core.python.* done done + +# # zip the output files and traces +# tar -czvf $out_dir.tar.gz $out_dir +# tar -czvf $trace_dir.tar.gz $trace_dir +# # remove the output files and traces +# rm -rf $out_dir $trace_dir +# diff --git a/benchmarks/run_all_jobs.sh b/benchmarks/run_all_jobs.sh index cfee491..2e6cb3f 100755 --- a/benchmarks/run_all_jobs.sh +++ b/benchmarks/run_all_jobs.sh @@ -1,19 +1,21 @@ #!/bin/bash # Run all slurms jobs -nodes_v100=(1 2 4 8 16) -nodes_a100=(1 2 4 8 16) +nodes_v100=(1 2 4 8 16 32) +nodes_a100=(1 2 4 8 16 32) for n in ${nodes_v100[@]}; do - sbatch --nodes=$n --job-name=v100_$n-JAXPM particle_mesh_v100.slurm + sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm done for n in ${nodes_a100[@]}; do - sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm + sbatch --account=tkc@a100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C a100 --job-name=JAXPM-$n-N-a100 particle_mesh.slurm done # single GPUs -sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_v100.slurm -sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_a100.slurm -sbatch --job-name=PMWD-v100 pmwd_v100.slurm -sbatch --job-name=PMWD-a100 pmwd_a100.slurm +sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=JAXPM-1GPU-V100 particle_mesh.slurm +sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-A100 particle_mesh.slurm +sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=PMWD-1GPU-v100 pmwd_pm.slurm +sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-a100 pmwd_pm.slurm + + diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 54377cb..5498b3d 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -91,8 +91,8 @@ def get_halo_size(halo_size): def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): mesh = mesh_lib.thread_resources.env.physical_mesh - if distributed and not (mesh.empty) and (halo_extents[0] > 0 - or halo_extents[1] > 0): + if distributed and not (mesh.empty) and (halo_extents > 0 + or halo_extents > 0): return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) else: return x diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 975e43c..597f3aa 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None): @partial(jax.jit, static_argnums=(2, )) def cic_paint(mesh, positions, halo_size=0, weight=None): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = autoshmap(cic_paint_impl, in_specs=(P('x', 'y'), P('x', 'y'), P()), out_specs=P('x', 'y'))(mesh, positions, weight) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) - mesh = slice_unpad(mesh, halo_size) + halo_extents=halo_size // 2, + halo_periods=True) + mesh = slice_unpad(mesh, halo_padding) return mesh @@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement): @partial(jax.jit, static_argnums=(2, )) def cic_read(mesh, displacement, halo_size=0): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) + halo_extents=halo_size//2, + halo_periods=True) displacement = autoshmap(cic_read_impl, in_specs=(P('x', 'y'), P('x', 'y')), out_specs=P('x', 'y'))(mesh, displacement) @@ -159,17 +159,24 @@ def cic_paint_dx_impl(displacements, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_paint_dx(displacements, halo_size=0): +<<<<<<< HEAD halo_size, halo_extents = get_halo_size(halo_size) mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), +======= + + halo_padding, halo_extents = get_halo_size(halo_size) + + mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding), +>>>>>>> glab/ASKabalan/jaxdecomp_proto in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(displacements) mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True, True)) - mesh = slice_unpad(mesh, halo_size) + halo_extents=halo_size//2, + halo_periods=True) + mesh = slice_unpad(mesh, halo_padding) return mesh @@ -196,12 +203,18 @@ def cic_read_dx_impl(mesh, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_read_dx(mesh, halo_size=0): # return mesh - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_padding, halo_extents = get_halo_size(halo_size) + mesh = slice_pad(mesh, halo_padding) mesh = halo_exchange(mesh, +<<<<<<< HEAD halo_extents=halo_extents, halo_periods=(True, True, True)) displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), +======= + halo_extents=halo_size//2, + halo_periods=True) + displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding), +>>>>>>> glab/ASKabalan/jaxdecomp_proto in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(mesh) diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm index 2585d5d..9c4cbac 100644 --- a/scripts/particle_mesh.slurm +++ b/scripts/particle_mesh.slurm @@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59 echo "The number of nodes allocated for this job is: $num_nodes" echo "The number of GPUs allocated for this job is: $nb_gpus" -export EQX_ON_ERROR=nan -export CUDA_ALLOC=1 +export ENABLE_PERFO_STEP=NVTX +export MPI4JAX_USE_CUDA_MPI=1 function profile_python() { if [ $# -lt 1 ]; then @@ -122,6 +122,7 @@ set -x declare -A pdims_table # Define the table +pdims_table[1]="1x1" pdims_table[4]="2x2 1x4" pdims_table[8]="2x4 1x8" pdims_table[16]="2x8 1x16"