mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-18 17:10:54 +00:00
update benchmark and add slurm
This commit is contained in:
parent
ed8cf8e532
commit
0216837033
2 changed files with 226 additions and 24 deletions
|
@ -15,8 +15,8 @@ import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from cupy.cuda.nvtx import RangePop, RangePush
|
from cupy.cuda.nvtx import RangePop, RangePush
|
||||||
from diffrax import (Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
diffeqsolve)
|
PIDController, SaveAt, Tsit5, diffeqsolve)
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.experimental.multihost_utils import sync_global_devices
|
from jax.experimental.multihost_utils import sync_global_devices
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
@ -29,7 +29,8 @@ from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
|
||||||
def chrono_fun(fun, *args):
|
def chrono_fun(fun, *args):
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
out = fun(*args).block_until_ready()
|
out = fun(*args)
|
||||||
|
out[0].block_until_ready()
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
return out, end - start
|
return out, end - start
|
||||||
|
|
||||||
|
@ -59,18 +60,22 @@ def run_simulation(mesh_shape,
|
||||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
||||||
|
|
||||||
# Evolve the simulation forward
|
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
|
||||||
term = ODETerm(
|
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
|
||||||
|
|
||||||
if solver_choice == "Dopri5":
|
if solver_choice == "Dopri5":
|
||||||
solver = Dopri5()
|
solver = Dopri5()
|
||||||
elif solver_choice == "LeapfrogMidpoint":
|
elif solver_choice == "LeapfrogMidpoint":
|
||||||
solver = LeapfrogMidpoint()
|
solver = LeapfrogMidpoint()
|
||||||
|
elif solver_choice == "Tsit5":
|
||||||
|
solver = Tsit5()
|
||||||
|
elif solver_choice == "lpt":
|
||||||
|
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
|
||||||
|
return lpt_field, {"num_steps": 0, "Solver": "LPT"}
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.")
|
"Invalid solver choice. Use 'Dopri5' or 'LeapfrogMidpoint'.")
|
||||||
|
# Evolve the simulation forward
|
||||||
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||||
|
term = ODETerm(
|
||||||
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
|
|
||||||
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
|
||||||
res = diffeqsolve(term,
|
res = diffeqsolve(term,
|
||||||
|
@ -93,17 +98,17 @@ def run_simulation(mesh_shape,
|
||||||
# Warm start
|
# Warm start
|
||||||
times = []
|
times = []
|
||||||
RangePush("warmup")
|
RangePush("warmup")
|
||||||
final_field, stats, warmup_time = chrono_fun(simulate, 0.32, 0.8)
|
(final_field, stats), warmup_time = chrono_fun(simulate, 0.32, 0.8)
|
||||||
RangePop()
|
RangePop()
|
||||||
sync_global_devices("warmup")
|
sync_global_devices("warmup")
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
RangePush(f"sim iter {i}")
|
RangePush(f"sim iter {i}")
|
||||||
final_field, stats, sim_time = chrono_fun(simulate, 0.32, 0.8)
|
(final_field, stats), sim_time = chrono_fun(simulate, 0.32, 0.8)
|
||||||
RangePop()
|
RangePop()
|
||||||
times.append(sim_time)
|
times.append(sim_time)
|
||||||
return stats, warmup_time, times, final_field
|
return stats, warmup_time, times, final_field
|
||||||
|
|
||||||
if jax.device_count() > 1 and pdims:
|
if jax.device_count() > 1:
|
||||||
devices = mesh_utils.create_device_mesh(pdims)
|
devices = mesh_utils.create_device_mesh(pdims)
|
||||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
with mesh:
|
with mesh:
|
||||||
|
@ -134,7 +139,7 @@ if __name__ == "__main__":
|
||||||
type=str,
|
type=str,
|
||||||
help='Processor dimensions',
|
help='Processor dimensions',
|
||||||
default=None)
|
default=None)
|
||||||
parser.add_argument('-h',
|
parser.add_argument('-hs',
|
||||||
'--halo_size',
|
'--halo_size',
|
||||||
type=int,
|
type=int,
|
||||||
help='Halo size',
|
help='Halo size',
|
||||||
|
@ -143,7 +148,11 @@ if __name__ == "__main__":
|
||||||
'--solver',
|
'--solver',
|
||||||
type=str,
|
type=str,
|
||||||
help='Solver',
|
help='Solver',
|
||||||
choices=["Dopri5", "LeapfrogMidpoint"],
|
choices=[
|
||||||
|
"Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
|
||||||
|
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
|
||||||
|
"lpt"
|
||||||
|
],
|
||||||
required=True)
|
required=True)
|
||||||
parser.add_argument('-i',
|
parser.add_argument('-i',
|
||||||
'--iterations',
|
'--iterations',
|
||||||
|
@ -170,10 +179,24 @@ if __name__ == "__main__":
|
||||||
output_path = args.output_path
|
output_path = args.output_path
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
|
||||||
|
match solver_choice:
|
||||||
|
case "Dopri5", "dopri5", "d5":
|
||||||
|
solver_choice = "Dopri5"
|
||||||
|
case "Tsit5", "tsit5", "t5":
|
||||||
|
solver_choice = "Tsit5"
|
||||||
|
case "LeapfrogMidpoint", "leapfrogmidpoint", "lfm":
|
||||||
|
solver_choice = "LeapfrogMidpoint"
|
||||||
|
case "lpt":
|
||||||
|
solver_choice = "lpt"
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt"
|
||||||
|
)
|
||||||
|
|
||||||
if args.pdims:
|
if args.pdims:
|
||||||
pdims = tuple(map(int, args.pdims.split("x")))
|
pdims = tuple(map(int, args.pdims.split("x")))
|
||||||
else:
|
else:
|
||||||
pdims = None
|
pdims = (1, 1)
|
||||||
|
|
||||||
mesh_shape = [mesh_size] * 3
|
mesh_shape = [mesh_size] * 3
|
||||||
|
|
||||||
|
@ -184,14 +207,6 @@ if __name__ == "__main__":
|
||||||
iterations,
|
iterations,
|
||||||
pdims=pdims)
|
pdims=pdims)
|
||||||
|
|
||||||
# Save the final field
|
|
||||||
if args.save_fields:
|
|
||||||
nb_gpus = jax.device_count()
|
|
||||||
field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{box_size[0]}/{solver_choice}/{halo_size}"
|
|
||||||
os.makedirs(field_folder, exist_ok=True)
|
|
||||||
np.save(f'{field_folder}/final_field_{rank}.npy',
|
|
||||||
final_field.addressable_data(0))
|
|
||||||
|
|
||||||
# Write benchmark results to CSV
|
# Write benchmark results to CSV
|
||||||
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
# RANK SIZE MESHSIZE BOX HALO SOLVER NUM_STEPS JITTIME MIN MAX MEAN STD
|
||||||
times = np.array(times)
|
times = np.array(times)
|
||||||
|
@ -200,11 +215,31 @@ if __name__ == "__main__":
|
||||||
max_time = np.max(times) * 1000
|
max_time = np.max(times) * 1000
|
||||||
mean_time = np.mean(times) * 1000
|
mean_time = np.mean(times) * 1000
|
||||||
std_time = np.std(times) * 1000
|
std_time = np.std(times) * 1000
|
||||||
|
|
||||||
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
with open(f"{output_path}/jax_pm_benchmark.csv", 'a') as f:
|
||||||
f.write(
|
f.write(
|
||||||
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{iterations},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
|
f"{rank},{size},{mesh_size},{box_size[0]},{halo_size},{solver_choice},{stats['num_steps']},{jit_in_ms},{min_time},{max_time},{mean_time},{std_time}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Save the final field
|
||||||
|
nb_gpus = jax.device_count()
|
||||||
|
pdm_str = f"{pdims[0]}x{pdims[1]}"
|
||||||
|
field_folder = f"{output_path}/final_field/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/{halo_size}"
|
||||||
|
os.makedirs(field_folder, exist_ok=True)
|
||||||
|
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||||
|
f.write(f"Args: {args}\n")
|
||||||
|
f.write(f"JIT time: {jit_in_ms:.4f} ms\n")
|
||||||
|
f.write(f"Min time: {min_time:.4f} ms\n")
|
||||||
|
f.write(f"Max time: {max_time:.4f} ms\n")
|
||||||
|
f.write(f"Mean time: {mean_time:.4f} ms\n")
|
||||||
|
f.write(f"Std time: {std_time:.4f} ms\n")
|
||||||
|
f.write(f"Stats: {stats}\n")
|
||||||
|
if args.save_fields:
|
||||||
|
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||||
|
final_field.addressable_data(0))
|
||||||
|
|
||||||
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
print(f"Finished! Warmup time: {warmup_time:.4f} seconds")
|
||||||
print(f"mean times: {np.mean(times):.4f}")
|
print(f"mean times: {np.mean(times):.4f}")
|
||||||
print(f"Stats")
|
print(f"Stats {stats}")
|
||||||
|
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
|
||||||
|
print(f"Saving field and logs in {field_folder}")
|
||||||
|
|
167
scripts/particle_mesh.slurm
Normal file
167
scripts/particle_mesh.slurm
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
#!/bin/bash
|
||||||
|
##########################################
|
||||||
|
## SELECT EITHER tkc@a100 OR tkc@v100 ##
|
||||||
|
##########################################
|
||||||
|
#SBATCH --account tkc@a100
|
||||||
|
##########################################
|
||||||
|
#SBATCH --job-name=Particle-Mesh # nom du job
|
||||||
|
# Il est possible d'utiliser une autre partition que celle par default
|
||||||
|
# en activant l'une des 5 directives suivantes :
|
||||||
|
##########################################
|
||||||
|
## SELECT EITHER a100 or v100-32g ##
|
||||||
|
##########################################
|
||||||
|
#SBATCH -C a100
|
||||||
|
##########################################
|
||||||
|
#******************************************
|
||||||
|
##########################################
|
||||||
|
## SELECT Number of nodes and GPUs per node
|
||||||
|
## For A100 ntasks-per-node and gres=gpu should be 8
|
||||||
|
## For V100 ntasks-per-node and gres=gpu should be 4
|
||||||
|
##########################################
|
||||||
|
#SBATCH --nodes=1 # nombre de noeud
|
||||||
|
#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
|
||||||
|
#SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
|
||||||
|
##########################################
|
||||||
|
## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
|
||||||
|
## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
|
||||||
|
## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
|
||||||
|
##########################################
|
||||||
|
#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
|
||||||
|
##########################################
|
||||||
|
# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
|
||||||
|
#SBATCH --hint=nomultithread # hyperthreading desactive
|
||||||
|
#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||||
|
#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
|
||||||
|
#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie)
|
||||||
|
#SBATCH --qos=qos_gpu-dev
|
||||||
|
#SBATCH --exclusive # ressources dediees
|
||||||
|
|
||||||
|
# Nettoyage des modules charges en interactif et herites par defaut
|
||||||
|
num_nodes=$SLURM_JOB_NUM_NODES
|
||||||
|
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
|
||||||
|
OUTPUT_FOLDER_ARGS=1
|
||||||
|
# Calculate the number of GPUs
|
||||||
|
nb_gpus=$(( num_nodes * num_gpu_per_node))
|
||||||
|
|
||||||
|
module purge
|
||||||
|
|
||||||
|
# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
|
||||||
|
# pour avoir acces aux modules compatibles avec cette partition
|
||||||
|
|
||||||
|
if [ $num_gpu_per_node -eq 8 ]; then
|
||||||
|
module load cpuarch/amd
|
||||||
|
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
|
||||||
|
else
|
||||||
|
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Chargement des modules
|
||||||
|
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
|
||||||
|
module load nvidia-nsight-systems/2024.1.1.59
|
||||||
|
|
||||||
|
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||||
|
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||||
|
|
||||||
|
export EQX_ON_ERROR=nan
|
||||||
|
export CUDA_ALLOC=1
|
||||||
|
|
||||||
|
function profile_python() {
|
||||||
|
if [ $# -lt 1 ]; then
|
||||||
|
echo "Usage: profile_python <python_script> [arguments for the script]"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local script_name=$(basename "$1" .py)
|
||||||
|
local output_dir="prof_traces/$script_name"
|
||||||
|
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
|
||||||
|
|
||||||
|
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||||
|
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||||
|
# Remove characters '/' and '-' from folder name
|
||||||
|
args=$(echo "$args" | tr -d '/-')
|
||||||
|
output_dir="prof_traces/$script_name/$args"
|
||||||
|
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p "$output_dir"
|
||||||
|
mkdir -p "$report_dir"
|
||||||
|
|
||||||
|
srun nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||||
|
}
|
||||||
|
|
||||||
|
function run_python() {
|
||||||
|
if [ $# -lt 1 ]; then
|
||||||
|
echo "Usage: run_python <python_script> [arguments for the script]"
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local script_name=$(basename "$1" .py)
|
||||||
|
local output_dir="traces/$script_name"
|
||||||
|
|
||||||
|
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||||
|
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||||
|
# Remove characters '/' and '-' from folder name
|
||||||
|
args=$(echo "$args" | tr -d '/-')
|
||||||
|
output_dir="traces/$script_name/$args"
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p "$output_dir"
|
||||||
|
|
||||||
|
srun python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Echo des commandes lancees
|
||||||
|
set -x
|
||||||
|
|
||||||
|
# Pour la partition "gpu_p5", le code doit etre compile avec les modules compatibles
|
||||||
|
# Execution du code avec binding via bind_gpu.sh : 1 GPU par tache
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
declare -A pdims_table
|
||||||
|
# Define the table
|
||||||
|
pdims_table[4]="2x2 1x4"
|
||||||
|
pdims_table[8]="2x4 1x8"
|
||||||
|
pdims_table[16]="2x8 1x16"
|
||||||
|
pdims_table[32]="4x8 1x32"
|
||||||
|
pdims_table[64]="4x16 1x64"
|
||||||
|
pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2"
|
||||||
|
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
|
||||||
|
|
||||||
|
|
||||||
|
#mpch=(128 256 512 1024 2048 4096)
|
||||||
|
grid=(1024 2048 4096)
|
||||||
|
|
||||||
|
pdim="${pdims_table[$nb_gpus]}"
|
||||||
|
echo "pdims: $pdim"
|
||||||
|
|
||||||
|
# Check if pdims is not empty
|
||||||
|
if [ -z "$pdim" ]; then
|
||||||
|
echo "pdims is empty"
|
||||||
|
echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
|
||||||
|
echo "Number of nodes selected: $num_nodes"
|
||||||
|
echo "Number of gpus per node: $num_gpu_per_node"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
|
||||||
|
|
||||||
|
if [ $num_gpu_per_node -eq 8 ]; then
|
||||||
|
gpu_name="a100"
|
||||||
|
else
|
||||||
|
gpu_name="v100"
|
||||||
|
fi
|
||||||
|
|
||||||
|
out_dir="out/$gpu_name/$nb_gpus"
|
||||||
|
|
||||||
|
echo "Output dir is : $out_dir"
|
||||||
|
|
||||||
|
for g in ${grid[@]}; do
|
||||||
|
for p in ${pdim[@]}; do
|
||||||
|
# halo is 1/4 of the grid size
|
||||||
|
halo_size=$((g / 4))
|
||||||
|
slaunch scripts/fastpm_jaxdecomp.py -m $g -b $g -p $p -hs $halo_size -ode diffrax -o $out_dir
|
||||||
|
done
|
||||||
|
done
|
Loading…
Add table
Reference in a new issue