diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b716a49..2c2299d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,7 @@ jobs: run: | sudo apt-get install -y libopenmpi-dev python -m pip install --upgrade pip - pip install jax=0.4.35 + pip install jax=0.4.35 pip install .[test] - name: Run Single Device Tests diff --git a/README.md b/README.md index 8af7a22..927af5b 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Provide a modern infrastructure to support differentiable PM N-body simulations - Any order forward and backward automatic differentiation - Support automated batching using `vmap` - Compatibility with external optimizer libraries like `optax` -- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with the latex `JAX v0.4.35` +- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with`JAX v0.4.35` ## Open development and use diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py deleted file mode 100644 index 5f25aad..0000000 --- a/benchmarks/bench_pm.py +++ /dev/null @@ -1,270 +0,0 @@ -import os - -os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax -import jax - -jax.distributed.initialize() - -rank = jax.process_index() -size = jax.process_count() - -import argparse -import time - -import jax.numpy as jnp -import jax_cosmo as jc -import numpy as np -from cupy.cuda.nvtx import RangePop, RangePush -from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, - PIDController, SaveAt, Tsit5, diffeqsolve) -from hpc_plotter.timer import Timer -from jax.experimental import mesh_utils -from jax.experimental.multihost_utils import sync_global_devices -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P - -from jaxpm.kernels import interpolate_power_spectrum -from jaxpm.painting import cic_paint_dx -from jaxpm.pm import linear_field, lpt, make_ode_fn - - -def run_simulation(mesh_shape, - box_size, - halo_size, - solver_choice, - iterations, - hlo_print, - trace, - pdims=None, - output_path="."): - - @jax.jit - def simulate(omega_c, sigma8): - # Create a small function to generate the matter power spectrum - k = jnp.logspace(-4, 1, 128) - pk = jc.power.linear_matter_power( - jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) - pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) - - # Create initial conditions - initial_conditions = linear_field(mesh_shape, - box_size, - pk_fn, - seed=jax.random.PRNGKey(0)) - - # Create particles - cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) - dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) - if solver_choice == "Dopri5": - solver = Dopri5() - elif solver_choice == "LeapfrogMidpoint": - solver = LeapfrogMidpoint() - elif solver_choice == "Tsit5": - solver = Tsit5() - elif solver_choice == "lpt": - lpt_field = cic_paint_dx(dx, halo_size=halo_size) - return lpt_field, {"num_steps": 0} - else: - raise ValueError( - "Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.") - # Evolve the simulation forward - ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) - term = ODETerm( - lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) - - if solver_choice == "Dopri5" or solver_choice == "Tsit5": - stepsize_controller = PIDController(rtol=1e-4, atol=1e-4) - elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler": - stepsize_controller = ConstantStepSize() - res = diffeqsolve(term, - solver, - t0=0.1, - t1=1., - dt0=0.01, - y0=jnp.stack([dx, p], axis=0), - args=cosmo, - saveat=SaveAt(t1=True), - stepsize_controller=stepsize_controller) - - # Return the simulation volume at requested - state = res.ys[-1] - final_field = cic_paint_dx(state[0], halo_size=halo_size) - - return final_field, res.stats - - def run(): - # Warm start - chrono_fun = Timer() - RangePush("warmup") - final_field, stats = chrono_fun.chrono_jit(simulate, - 0.32, - 0.8, - ndarray_arg=0) - RangePop() - sync_global_devices("warmup") - for i in range(iterations): - RangePush(f"sim iter {i}") - final_field, stats = chrono_fun.chrono_fun(simulate, - 0.32, - 0.8, - ndarray_arg=0) - RangePop() - return final_field, stats, chrono_fun - - if jax.device_count() > 1: - devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices.T, axis_names=('x', 'y')) - with mesh: - # Warm start - final_field, stats, chrono_fun = run() - else: - final_field, stats, chrono_fun = run() - - return final_field, stats, chrono_fun - - -if __name__ == "__main__": - - parser = argparse.ArgumentParser( - description='JAX Cosmo Simulation Benchmark') - parser.add_argument('-m', - '--mesh_size', - type=int, - help='Mesh size', - required=True) - parser.add_argument('-b', - '--box_size', - type=float, - help='Box size', - required=True) - parser.add_argument('-p', - '--pdims', - type=str, - help='Processor dimensions', - default=None) - parser.add_argument( - '-pr', - '--precision', - type=str, - help='Precision', - choices=["float32", "float64"], - ) - parser.add_argument('-hs', - '--halo_size', - type=int, - help='Halo size', - default=None) - parser.add_argument('-s', - '--solver', - type=str, - help='Solver', - choices=[ - "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5", - "LeapfrogMidpoint", "leapfrogmidpoint", "lfm", - "lpt" - ], - default="lpt") - parser.add_argument('-o', - '--output_path', - type=str, - help='Output path', - default=".") - parser.add_argument('-f', - '--save_fields', - action='store_true', - help='Save fields') - parser.add_argument('-n', - '--nodes', - type=int, - help='Number of nodes', - default=1) - args = parser.parse_args() - mesh_size = args.mesh_size - box_size = [args.box_size] * 3 - halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size - solver_choice = args.solver - iterations = args.iterations - output_path = args.output_path - os.makedirs(output_path, exist_ok=True) - - print(f"solver choice: {solver_choice}") - match solver_choice: - case "Dopri5" | "dopri5" | "d5": - solver_choice = "Dopri5" - case "Tsit5" | "tsit5" | "t5": - solver_choice = "Tsit5" - case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm": - solver_choice = "LeapfrogMidpoint" - case "lpt": - solver_choice = "lpt" - case _: - raise ValueError( - "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt" - ) - if args.precision == "float32": - jax.config.update("jax_enable_x64", False) - elif args.precision == "float64": - jax.config.update("jax_enable_x64", True) - - if args.pdims: - pdims = tuple(map(int, args.pdims.split("x"))) - else: - pdims = (1, jax.device_count()) - pdm_str = f"{pdims[0]}x{pdims[1]}" - - mesh_shape = [mesh_size] * 3 - - final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, - halo_size, solver_choice, - iterations, pdims) - - print( - f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}" - ) - - metadata = { - 'rank': rank, - 'function_name': f'JAXPM-{solver_choice}', - 'precision': args.precision, - 'x': str(mesh_size), - 'y': str(mesh_size), - 'z': str(stats["num_steps"]), - 'px': str(pdims[0]), - 'py': str(pdims[1]), - 'backend': 'NCCL', - 'nodes': str(args.nodes) - } - # Print the results to a CSV file - chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata) - - # Save the final field - nb_gpus = jax.device_count() - pdm_str = f"{pdims[0]}x{pdims[1]}" - field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(field_folder, exist_ok=True) - with open(f'{field_folder}/jaxpm.log', 'w') as f: - f.write(f"Args: {args}\n") - f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i, time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") - f.write(f"Stats: {stats}\n") - if args.save_fields: - np.save(f'{field_folder}/final_field_0_{rank}.npy', - final_field.addressable_data(0)) - - field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}" - os.makedirs(field_folder, exist_ok=True) - with open(f'{field_folder}/jaxpm.log', 'w') as f: - f.write(f"Args: {args}\n") - f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i, time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") - f.write(f"Stats: {stats}\n") - if args.save_fields: - np.save(f'{field_folder}/final_field_0_{rank}.npy', - final_field.addressable_data(0)) - - print(f"Finished! ") - print(f"Stats {stats}") - print(f"Saving to {output_path}/jax_pm_benchmark.csv") - print(f"Saving field and logs in {field_folder}") diff --git a/benchmarks/bench_pmwd.py b/benchmarks/bench_pmwd.py deleted file mode 100644 index bd11303..0000000 --- a/benchmarks/bench_pmwd.py +++ /dev/null @@ -1,159 +0,0 @@ -import os - -# Change JAX GPU memory preallocation fraction -os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' - -import argparse - -import jax -import matplotlib.pyplot as plt -import numpy as np -from hpc_plotter.timer import Timer -from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth, - linear_modes, linear_power, lpt, nbody, scatter, white_noise) -from pmwd.pm_util import fftinv -from pmwd.spec_util import powspec -from pmwd.vis_util import simshow - - -# Simulation configuration -def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations): - - @jax.jit - def simulate(omega_m, sigma8): - - conf = Configuration(ptcl_spacing, - ptcl_grid_shape=ptcl_grid_shape, - mesh_shape=1, - lpt_order=1, - a_nbody_maxstep=1 / 91) - print(conf) - print( - f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.' - ) - - cosmo = Cosmology(conf, - A_s_1e9=2.0, - n_s=0.96, - Omega_m=omega_m, - Omega_b=sigma8, - h=0.7) - print(cosmo) - - # Boltzmann calculation - cosmo = boltzmann(cosmo, conf) - print("Boltzmann calculation completed.") - - # Generate white noise field and scale with the linear power spectrum - seed = 0 - modes = white_noise(seed, conf) - modes = linear_modes(modes, cosmo, conf) - print("Linear modes generated.") - - # Solve LPT at some early time - ptcl, obsvbl = lpt(modes, cosmo, conf) - print("LPT solved.") - - if solver == "lfm": - # N-body time integration from LPT initial conditions - ptcl, obsvbl = jax.block_until_ready( - nbody(ptcl, obsvbl, cosmo, conf)) - print("N-body time integration completed.") - - # Scatter particles to mesh to get the density field - dens = scatter(ptcl, conf) - return dens - - chrono_timer = Timer() - final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05) - - for _ in range(iterations): - final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05) - - return final_field, chrono_timer - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='PMWD Simulation') - parser.add_argument('-m', - '--mesh_size', - type=int, - help='Mesh size', - required=True) - parser.add_argument('-b', - '--box_size', - type=float, - help='Box size', - required=True) - parser.add_argument('-i', - '--iterations', - type=int, - help='Number of iterations', - default=10) - parser.add_argument('-o', - '--output_path', - type=str, - help='Output path', - default=".") - parser.add_argument('-f', - '--save_fields', - action='store_true', - help='Save fields') - parser.add_argument('-s', - '--solver', - type=str, - help='Solver', - choices=["lfm", "lpt"]) - parser.add_argument( - '-pr', - '--precision', - type=str, - help='Precision', - choices=["float32", "float64"], - ) - - args = parser.parse_args() - - mesh_shape = [args.mesh_size] * 3 - ptcl_spacing = args.box_size / args.mesh_size - iterations = args.iterations - solver = args.solver - output_path = args.output_path - if args.precision == "float32": - jax.config.update("jax_enable_x64", False) - elif args.precision == "float64": - jax.config.update("jax_enable_x64", True) - - os.makedirs(output_path, exist_ok=True) - - final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing, - solver, iterations) - print("PMWD simulation completed.") - - metadata = { - 'rank': 0, - 'function_name': f'PMWD-{solver}', - 'precision': args.precision, - 'x': str(mesh_shape[0]), - 'y': str(mesh_shape[0]), - 'z': str(mesh_shape[0]), - 'px': "1", - 'py': "1", - 'backend': 'NCCL', - 'nodes': "1" - } - chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata) - field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0" - os.makedirs(field_folder, exist_ok=True) - with open(f"{field_folder}/pmwd.log", "w") as f: - f.write(f"PMWD simulation completed.\n") - f.write(f"Args : {args}\n") - f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n") - for i, time in enumerate(chrono_fun.times): - f.write(f"Time {i}: {time:.4f} ms\n") - if args.save_fields: - np.save(f"{field_folder}/final_field_0_0.npy", final_field) - print("Fields saved.") - - print(f"saving to {output_path}/pmwd.csv") - print(f"saving field and logs to {field_folder}/pmwd.log") diff --git a/benchmarks/particle_mesh.slurm b/benchmarks/particle_mesh.slurm deleted file mode 100644 index 7e60678..0000000 --- a/benchmarks/particle_mesh.slurm +++ /dev/null @@ -1,157 +0,0 @@ -#!/bin/bash -############################################################################################################################## -# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm -############################################################################################################################## -#SBATCH --job-name=Particle-Mesh # nom du job -#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -#SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) -#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.err # nom du fichier d'erreur (ici commun avec la sortie) -#SBATCH --exclusive # ressources dediees -##SBATCH --qos=qos_gpu-dev -# Nettoyage des modules charges en interactif et herites par defaut -num_nodes=$SLURM_JOB_NUM_NODES -num_gpu_per_node=$SLURM_NTASKS_PER_NODE -OUTPUT_FOLDER_ARGS=1 -# Calculate the number of GPUs -nb_gpus=$(( num_nodes * num_gpu_per_node)) - -module purge - -echo "Job partition: $SLURM_JOB_PARTITION" -# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" -# pour avoir acces aux modules compatibles avec cette partition - -if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then - module load cpuarch/amd - source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate - gpu_name=a100 -else - source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate - gpu_name=v100 -fi - -# Chargement des modules -module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -module load nvidia-nsight-systems/2024.1.1.59 - - -echo "The number of nodes allocated for this job is: $num_nodes" -echo "The number of GPUs allocated for this job is: $nb_gpus" - -export EQX_ON_ERROR=nan -export ENABLE_PERFO_STEP=NVTX -export MPI4JAX_USE_CUDA_MPI=1 - -function profile_python() { - if [ $# -lt 1 ]; then - echo "Usage: profile_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" - local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" - report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - mkdir -p "$report_dir" - - srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - -function run_python() { - if [ $# -lt 1 ]; then - echo "Usage: run_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="traces/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - - srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - - -# run or profile - -function slaunch() { - run_python "$@" -} - -function plaunch() { - profile_python "$@" -} - -# Echo des commandes lancees -set -x - -# Pour ne pas utiliser le /tmp -export TMPDIR=$JOBSCRATCH -# Pour contourner un bogue dans les versions actuelles de Nsight Systems -# il est également nécessaire de créer un lien symbolique permettant de -# faire pointer le répertoire /tmp/nvidia vers TMPDIR -ln -s $JOBSCRATCH /tmp/nvidia - -declare -A pdims_table -# Define the table -pdims_table[1]="1x1" -pdims_table[4]="2x2 1x4 4x1" -pdims_table[8]="2x4 1x8 8x1 4x2" -pdims_table[16]="4x4 1x16 16x1" -pdims_table[32]="4x8 8x4 1x32 32x1" -pdims_table[64]="8x8 16x4 1x64 64x1" -pdims_table[128]="8x16 16x8 1x128 128x1" -pdims_table[256]="16x16 1x256 256x1" - -# mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096 8192) -precisions=(float32 float64) -pdim="${pdims_table[$nb_gpus]}" -solvers=(lpt lfm) -echo "pdims: $pdim" - -# Check if pdims is not empty -if [ -z "$pdim" ]; then - echo "pdims is empty" - echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160" - echo "Number of nodes selected: $num_nodes" - echo "Number of gpus per node: $num_gpu_per_node" - exit 1 -fi - -# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 -out_dir="pm_prof/$gpu_name/$nb_gpus" -trace_dir="traces/$gpu_name/$nb_gpus/bench_pm" -echo "Output dir is : $out_dir" -echo "Trace dir is : $trace_dir" - -for pr in "${precisions[@]}"; do - for g in "${grid[@]}"; do - for solver in "${solvers[@]}"; do - for p in $pdim; do - halo_size=$((g / 4)) - slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes - done - done - # delete crash core dump files - rm -f core.python.* - done -done diff --git a/benchmarks/pmwd_pm.slurm b/benchmarks/pmwd_pm.slurm deleted file mode 100644 index 4171a51..0000000 --- a/benchmarks/pmwd_pm.slurm +++ /dev/null @@ -1,147 +0,0 @@ -#!/bin/bash -############################################################################################################################## -# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm -############################################################################################################################## -#SBATCH --job-name=Particle-Mesh # nom du job -#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU) -#SBATCH --hint=nomultithread # hyperthreading desactive -#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS) -#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie -#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie) -#SBATCH --exclusive # ressources dediees -##SBATCH --qos=qos_gpu-dev -# Nettoyage des modules charges en interactif et herites par defaut -num_nodes=$SLURM_JOB_NUM_NODES -num_gpu_per_node=$SLURM_NTASKS_PER_NODE -OUTPUT_FOLDER_ARGS=1 -# Calculate the number of GPUs -nb_gpus=$(( num_nodes * num_gpu_per_node)) - -module purge - -echo "Job partition: $SLURM_JOB_PARTITION" -# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5" -# pour avoir acces aux modules compatibles avec cette partition - -if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then - module load cpuarch/amd - source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate - gpu_name=a100 -else - source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate - gpu_name=v100 -fi - -# Chargement des modules -module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake -module load nvidia-nsight-systems/2024.1.1.59 - - -echo "The number of nodes allocated for this job is: $num_nodes" -echo "The number of GPUs allocated for this job is: $nb_gpus" - -export EQX_ON_ERROR=nan -export ENABLE_PERFO_STEP=NVTX -export MPI4JAX_USE_CUDA_MPI=1 - -function profile_python() { - if [ $# -lt 1 ]; then - echo "Usage: profile_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name" - local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args" - report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - mkdir -p "$report_dir" - - srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - -function run_python() { - if [ $# -lt 1 ]; then - echo "Usage: run_python [arguments for the script]" - return 1 - fi - - local script_name=$(basename "$1" .py) - local output_dir="traces/$gpu_name/$nb_gpus/$script_name" - - if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then - local args=$(echo "${@:2}" | tr ' ' '_') - # Remove characters '/' and '-' from folder name - args=$(echo "$args" | tr -d '/-') - output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args" - fi - - mkdir -p "$output_dir" - - srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true -} - - - -# run or profile - -function slaunch() { - run_python "$@" -} - -function plaunch() { - profile_python "$@" -} - -# Echo des commandes lancees -set -x - -# Pour ne pas utiliser le /tmp -export TMPDIR=$JOBSCRATCH -# Pour contourner un bogue dans les versions actuelles de Nsight Systems -# il est également nécessaire de créer un lien symbolique permettant de -# faire pointer le répertoire /tmp/nvidia vers TMPDIR -ln -s $JOBSCRATCH /tmp/nvidia - -# mpch=(128 256 512 1024 2048 4096) -grid=(256 512 1024 2048 4096 8192) -precisions=(float32 float64) -solvers=(lpt lfm) - -# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100 - -if [ $num_gpu_per_node -eq 8 ]; then - gpu_name="a100" -else - gpu_name="v100" -fi - -out_dir="pm_prof/$gpu_name/$nb_gpus" -trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd" -echo "Output dir is : $out_dir" -echo "Trace dir is : $trace_dir" - -for pr in "${precisions[@]}"; do - for g in "${grid[@]}"; do - for solver in "${solvers[@]}"; do - slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f - done - # delete crash core dump files - rm -f core.python.* - done -done - -# # zip the output files and traces -# tar -czvf $out_dir.tar.gz $out_dir -# tar -czvf $trace_dir.tar.gz $trace_dir -# # remove the output files and traces -# rm -rf $out_dir $trace_dir -# diff --git a/benchmarks/run_all_jobs.sh b/benchmarks/run_all_jobs.sh deleted file mode 100755 index b0e3815..0000000 --- a/benchmarks/run_all_jobs.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -# Run all slurms jobs -nodes_v100=(1 2 4 8 16 32) -nodes_a100=(1 2 4 8 16 32) - - -for n in ${nodes_v100[@]}; do - sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm -done - -for n in ${nodes_a100[@]}; do - sbatch --account=tkc@a100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C a100 --job-name=JAXPM-$n-N-a100 particle_mesh.slurm -done - -# single GPUs -sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=JAXPM-1GPU-V100 particle_mesh.slurm -sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-A100 particle_mesh.slurm -sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=PMWD-1GPU-v100 pmwd_pm.slurm -sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-a100 pmwd_pm.slurm diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 44177e9..721a971 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -10,13 +10,13 @@ import jax.numpy as jnp import jaxdecomp from jax import lax from jax.experimental.shard_map import shard_map -from jax.sharding import Mesh +from jax.sharding import AbstractMesh, Mesh from jax.sharding import PartitionSpec as P def autoshmap( f: Callable, - gpu_mesh: Mesh | None, + gpu_mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs, check_rep: bool = False, @@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None): ] -def __axis_names(spec): +def _axis_names(spec): if len(spec) == 1: x_axis, = spec y_axis = None @@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None): if gpu_mesh is not None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) spec = sharding.spec - x_axis, y_axis, single_axis = __axis_names(spec) + x_axis, y_axis, single_axis = _axis_names(spec) def particles(): x_indx = lax.axis_index(x_axis) @@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None): # to make the code work both in multi host and single controller we can do this trick keys = jax.random.split(seed, size) spec = sharding.spec - x_axis, y_axis, single_axis = __axis_names(spec) + x_axis, y_axis, single_axis = _axis_names(spec) def normal(keys, shape, dtype): idx = lax.axis_index(x_axis) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index db695e3..3083f08 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -3,6 +3,7 @@ from functools import partial import jax import jax.lax as lax import jax.numpy as jnp +from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange, @@ -11,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def cic_paint_impl(grid_mesh, positions, weight=None): +def _cic_paint_impl(grid_mesh, positions, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] @@ -54,9 +55,9 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): halo_size, halo_extents = get_halo_size(halo_size, sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding) - gpu_mesh = sharding.mesh if sharding is not None else None - spec = sharding.spec if sharding is not None else P() - grid_mesh = autoshmap(cic_paint_impl, + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + grid_mesh = autoshmap(_cic_paint_impl, gpu_mesh=gpu_mesh, in_specs=(spec, spec, P()), out_specs=spec)(grid_mesh, positions, weight) @@ -68,7 +69,7 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): return grid_mesh -def cic_read_impl(grid_mesh, positions): +def _cic_read_impl(grid_mesh, positions): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [nx,ny,nz, 3] @@ -110,10 +111,10 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None): grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, halo_periods=(True, True)) - gpu_mesh = sharding.mesh if sharding is not None else None - spec = sharding.spec if sharding is not None else P() + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() - displacement = autoshmap(cic_read_impl, + displacement = autoshmap(_cic_read_impl, gpu_mesh=gpu_mesh, in_specs=(spec, spec), out_specs=spec)(grid_mesh, positions) @@ -150,7 +151,7 @@ def cic_paint_2d(mesh, positions, weight): return mesh -def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): +def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -187,9 +188,9 @@ def cic_paint_dx(displacements, halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) - gpu_mesh = sharding.mesh if sharding is not None else None - spec = sharding.spec if sharding is not None else P() - grid_mesh = autoshmap(partial(cic_paint_dx_impl, + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + grid_mesh = autoshmap(partial(_cic_paint_dx_impl, halo_size=halo_size, weight=weight, chunk_size=chunk_size), @@ -204,7 +205,7 @@ def cic_paint_dx(displacements, return grid_mesh -def cic_read_dx_impl(grid_mesh, disp, halo_size): +def _cic_read_dx_impl(grid_mesh, disp, halo_size): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -233,9 +234,9 @@ def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None): grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, halo_periods=(True, True)) - gpu_mesh = sharding.mesh if sharding is not None else None - spec = sharding.spec if sharding is not None else P() - displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), + gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None + spec = sharding.spec if isinstance(sharding, NamedSharding) else P() + displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size), gpu_mesh=gpu_mesh, in_specs=(spec), out_specs=spec)(grid_mesh, disp) diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index cec0bce..fd683ab 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -5,8 +5,8 @@ initialize_distributed() # ignore : E402 import jax # noqa : E402 import jax.numpy as jnp # noqa : E402 import pytest # noqa : E402 -from diffrax import (Dopri5, ODETerm, PIDController, SaveAt, # noqa : E402 - diffeqsolve) +from diffrax import SaveAt # noqa : E402 +from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve from helpers import MSE # noqa : E402 from jax import lax # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402