Add aboucaud comments

Wassim Kabalan 2024-12-08 23:11:11 +01:00
parent adaf7d236d
commit d8c68ace7a
10 changed files with 26 additions and 777 deletions

View file

@@ -29,7 +29,7 @@ jobs:
         run: |
           sudo apt-get install -y libopenmpi-dev
           python -m pip install --upgrade pip
           pip install jax==0.4.35
           pip install .[test]
       - name: Run Single Device Tests

View file

@@ -11,7 +11,7 @@ Provide a modern infrastructure to support differentiable PM N-body simulations
 - Any order forward and backward automatic differentiation
 - Support automated batching using `vmap`
 - Compatibility with external optimizer libraries like `optax`
-- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with the latex `JAX v0.4.35`
+- Now fully distributable on **multi-GPU and multi-node** systems using [jaxDecomp](https://github.com/DifferentiableUniverseInitiative/jaxDecomp) working with `JAX v0.4.35`
 ## Open development and use
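Note: a minimal sketch, not taken from the repository, of the workflow the feature list above describes. The toy mesh size, the scalar statistic, and calling `lpt` with `halo_size=0` are assumptions; the `jaxpm.pm` and `jaxpm.kernels` calls mirror the benchmark script removed later in this commit.

import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.pm import linear_field, lpt

mesh_shape = (64, 64, 64)     # toy resolution (assumption)
box_size = (64., 64., 64.)    # box size in Mpc/h (assumption)

def field_amplitude(sigma8):
    # Whole forward model: power spectrum -> initial conditions -> LPT displacement
    cosmo = jc.Planck15(sigma8=sigma8)
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(cosmo, k)
    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
    ic = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))
    dx, _, _ = lpt(cosmo, ic, 0.1, halo_size=0)
    return jnp.std(dx)                      # scalar summary statistic

grad_fn = jax.jit(jax.grad(field_amplitude))   # reverse-mode AD through the simulation
batched_fn = jax.vmap(field_amplitude)         # automated batching over sigma8 values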

View file

@@ -1,270 +0,0 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
import argparse
import time
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from cupy.cuda.nvtx import RangePop, RangePush
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                     PIDController, SaveAt, Tsit5, diffeqsolve)
from hpc_plotter.timer import Timer
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
def run_simulation(mesh_shape,
                   box_size,
                   halo_size,
                   solver_choice,
                   iterations,
                   hlo_print=False,
                   trace=False,
                   pdims=None,
                   output_path="."):

    @jax.jit
    def simulate(omega_c, sigma8):
        # Create a small function to generate the matter power spectrum
        k = jnp.logspace(-4, 1, 128)
        pk = jc.power.linear_matter_power(
            jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
        pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
        # Create initial conditions
        initial_conditions = linear_field(mesh_shape,
                                          box_size,
                                          pk_fn,
                                          seed=jax.random.PRNGKey(0))
        # Create particles
        cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
        dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
        if solver_choice == "Dopri5":
            solver = Dopri5()
        elif solver_choice == "LeapfrogMidpoint":
            solver = LeapfrogMidpoint()
        elif solver_choice == "Tsit5":
            solver = Tsit5()
        elif solver_choice == "lpt":
            lpt_field = cic_paint_dx(dx, halo_size=halo_size)
            return lpt_field, {"num_steps": 0}
        else:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'.")
        # Evolve the simulation forward
        ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
        term = ODETerm(
            lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
        if solver_choice == "Dopri5" or solver_choice == "Tsit5":
            stepsize_controller = PIDController(rtol=1e-4, atol=1e-4)
        elif solver_choice == "LeapfrogMidpoint" or solver_choice == "Euler":
            stepsize_controller = ConstantStepSize()
        res = diffeqsolve(term,
                          solver,
                          t0=0.1,
                          t1=1.,
                          dt0=0.01,
                          y0=jnp.stack([dx, p], axis=0),
                          args=cosmo,
                          saveat=SaveAt(t1=True),
                          stepsize_controller=stepsize_controller)
        # Return the simulation volume at the requested time
        state = res.ys[-1]
        final_field = cic_paint_dx(state[0], halo_size=halo_size)
        return final_field, res.stats

    def run():
        # Warm start
        chrono_fun = Timer()
        RangePush("warmup")
        final_field, stats = chrono_fun.chrono_jit(simulate,
                                                   0.32,
                                                   0.8,
                                                   ndarray_arg=0)
        RangePop()
        sync_global_devices("warmup")
        for i in range(iterations):
            RangePush(f"sim iter {i}")
            final_field, stats = chrono_fun.chrono_fun(simulate,
                                                       0.32,
                                                       0.8,
                                                       ndarray_arg=0)
            RangePop()
        return final_field, stats, chrono_fun

    if jax.device_count() > 1:
        devices = mesh_utils.create_device_mesh(pdims)
        mesh = Mesh(devices.T, axis_names=('x', 'y'))
        with mesh:
            # Warm start
            final_field, stats, chrono_fun = run()
    else:
        final_field, stats, chrono_fun = run()

    return final_field, stats, chrono_fun
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='JAX Cosmo Simulation Benchmark')
    parser.add_argument('-m',
                        '--mesh_size',
                        type=int,
                        help='Mesh size',
                        required=True)
    parser.add_argument('-b',
                        '--box_size',
                        type=float,
                        help='Box size',
                        required=True)
    parser.add_argument('-i',
                        '--iterations',
                        type=int,
                        help='Number of iterations',
                        default=10)
    parser.add_argument('-p',
                        '--pdims',
                        type=str,
                        help='Processor dimensions',
                        default=None)
    parser.add_argument(
        '-pr',
        '--precision',
        type=str,
        help='Precision',
        choices=["float32", "float64"],
    )
    parser.add_argument('-hs',
                        '--halo_size',
                        type=int,
                        help='Halo size',
                        default=None)
    parser.add_argument('-s',
                        '--solver',
                        type=str,
                        help='Solver',
                        choices=[
                            "Dopri5", "dopri5", "d5", "Tsit5", "tsit5", "t5",
                            "LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
                            "lpt"
                        ],
                        default="lpt")
    parser.add_argument('-o',
                        '--output_path',
                        type=str,
                        help='Output path',
                        default=".")
    parser.add_argument('-f',
                        '--save_fields',
                        action='store_true',
                        help='Save fields')
    parser.add_argument('-n',
                        '--nodes',
                        type=int,
                        help='Number of nodes',
                        default=1)
    args = parser.parse_args()

    mesh_size = args.mesh_size
    box_size = [args.box_size] * 3
    halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size
    solver_choice = args.solver
    iterations = args.iterations
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)
    print(f"solver choice: {solver_choice}")
    match solver_choice:
        case "Dopri5" | "dopri5" | "d5":
            solver_choice = "Dopri5"
        case "Tsit5" | "tsit5" | "t5":
            solver_choice = "Tsit5"
        case "LeapfrogMidpoint" | "leapfrogmidpoint" | "lfm":
            solver_choice = "LeapfrogMidpoint"
        case "lpt":
            solver_choice = "lpt"
        case _:
            raise ValueError(
                "Invalid solver choice. Use 'Dopri5', 'Tsit5', 'LeapfrogMidpoint' or 'lpt'."
            )
    if args.precision == "float32":
        jax.config.update("jax_enable_x64", False)
    elif args.precision == "float64":
        jax.config.update("jax_enable_x64", True)

    if args.pdims:
        pdims = tuple(map(int, args.pdims.split("x")))
    else:
        pdims = (1, jax.device_count())
    pdm_str = f"{pdims[0]}x{pdims[1]}"

    mesh_shape = [mesh_size] * 3

    final_field, stats, chrono_fun = run_simulation(mesh_shape,
                                                    box_size,
                                                    halo_size,
                                                    solver_choice,
                                                    iterations,
                                                    pdims=pdims,
                                                    output_path=output_path)
    print(
        f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}"
    )

    metadata = {
        'rank': rank,
        'function_name': f'JAXPM-{solver_choice}',
        'precision': args.precision,
        'x': str(mesh_size),
        'y': str(mesh_size),
        'z': str(stats["num_steps"]),
        'px': str(pdims[0]),
        'py': str(pdims[1]),
        'backend': 'NCCL',
        'nodes': str(args.nodes)
    }
    # Print the results to a CSV file
    chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)

    # Save the final field
    nb_gpus = jax.device_count()
    field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
    os.makedirs(field_folder, exist_ok=True)
    with open(f'{field_folder}/jaxpm.log', 'w') as f:
        f.write(f"Args: {args}\n")
        f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
        for i, t in enumerate(chrono_fun.times):
            f.write(f"Time {i}: {t:.4f} ms\n")
        f.write(f"Stats: {stats}\n")
    if args.save_fields:
        np.save(f'{field_folder}/final_field_0_{rank}.npy',
                final_field.addressable_data(0))
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
os.makedirs(field_folder, exist_ok=True)
with open(f'{field_folder}/jaxpm.log', 'w') as f:
f.write(f"Args: {args}\n")
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
for i, time in enumerate(chrono_fun.times):
f.write(f"Time {i}: {time:.4f} ms\n")
f.write(f"Stats: {stats}\n")
if args.save_fields:
np.save(f'{field_folder}/final_field_0_{rank}.npy',
final_field.addressable_data(0))
print(f"Finished! ")
print(f"Stats {stats}")
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
print(f"Saving field and logs in {field_folder}")

View file

@@ -1,159 +0,0 @@
import os
# Change JAX GPU memory preallocation fraction
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95'
import argparse
import jax
import matplotlib.pyplot as plt
import numpy as np
from hpc_plotter.timer import Timer
from pmwd import (Configuration, Cosmology, SimpleLCDM, boltzmann, growth,
                  linear_modes, linear_power, lpt, nbody, scatter, white_noise)
from pmwd.pm_util import fftinv
from pmwd.spec_util import powspec
from pmwd.vis_util import simshow
# Simulation configuration
def run_pmwd_simulation(ptcl_grid_shape, ptcl_spacing, solver, iterations):

    @jax.jit
    def simulate(omega_m, omega_b):
        conf = Configuration(ptcl_spacing,
                             ptcl_grid_shape=ptcl_grid_shape,
                             mesh_shape=1,
                             lpt_order=1,
                             a_nbody_maxstep=1 / 91)
        print(conf)
        print(
            f'Simulating {conf.ptcl_num} particles with a {conf.mesh_shape} mesh for {conf.a_nbody_num} time steps.'
        )
        cosmo = Cosmology(conf,
                          A_s_1e9=2.0,
                          n_s=0.96,
                          Omega_m=omega_m,
                          Omega_b=omega_b,
                          h=0.7)
        print(cosmo)
        # Boltzmann calculation
        cosmo = boltzmann(cosmo, conf)
        print("Boltzmann calculation completed.")
        # Generate white noise field and scale with the linear power spectrum
        seed = 0
        modes = white_noise(seed, conf)
        modes = linear_modes(modes, cosmo, conf)
        print("Linear modes generated.")
        # Solve LPT at some early time
        ptcl, obsvbl = lpt(modes, cosmo, conf)
        print("LPT solved.")
        if solver == "lfm":
            # N-body time integration from LPT initial conditions
            ptcl, obsvbl = jax.block_until_ready(
                nbody(ptcl, obsvbl, cosmo, conf))
            print("N-body time integration completed.")
        # Scatter particles to mesh to get the density field
        dens = scatter(ptcl, conf)
        return dens

    chrono_timer = Timer()
    final_field = chrono_timer.chrono_jit(simulate, 0.3, 0.05)
    for _ in range(iterations):
        final_field = chrono_timer.chrono_fun(simulate, 0.3, 0.05)
    return final_field, chrono_timer
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PMWD Simulation')
    parser.add_argument('-m',
                        '--mesh_size',
                        type=int,
                        help='Mesh size',
                        required=True)
    parser.add_argument('-b',
                        '--box_size',
                        type=float,
                        help='Box size',
                        required=True)
    parser.add_argument('-i',
                        '--iterations',
                        type=int,
                        help='Number of iterations',
                        default=10)
    parser.add_argument('-o',
                        '--output_path',
                        type=str,
                        help='Output path',
                        default=".")
    parser.add_argument('-f',
                        '--save_fields',
                        action='store_true',
                        help='Save fields')
    parser.add_argument('-s',
                        '--solver',
                        type=str,
                        help='Solver',
                        choices=["lfm", "lpt"])
    parser.add_argument(
        '-pr',
        '--precision',
        type=str,
        help='Precision',
        choices=["float32", "float64"],
    )
    args = parser.parse_args()

    mesh_shape = [args.mesh_size] * 3
    ptcl_spacing = args.box_size / args.mesh_size
    iterations = args.iterations
    solver = args.solver
    output_path = args.output_path

    if args.precision == "float32":
        jax.config.update("jax_enable_x64", False)
    elif args.precision == "float64":
        jax.config.update("jax_enable_x64", True)

    os.makedirs(output_path, exist_ok=True)

    final_field, chrono_fun = run_pmwd_simulation(mesh_shape, ptcl_spacing,
                                                  solver, iterations)
    print("PMWD simulation completed.")

    metadata = {
        'rank': 0,
        'function_name': f'PMWD-{solver}',
        'precision': args.precision,
        'x': str(mesh_shape[0]),
        'y': str(mesh_shape[0]),
        'z': str(mesh_shape[0]),
        'px': "1",
        'py': "1",
        'backend': 'NCCL',
        'nodes': "1"
    }
    chrono_fun.print_to_csv(f"{output_path}/pmwd.csv", **metadata)

    field_folder = f"{output_path}/final_field/pmwd/1/{args.mesh_size}_{int(args.box_size)}/1x1/{args.solver}/halo_0"
    os.makedirs(field_folder, exist_ok=True)
    with open(f"{field_folder}/pmwd.log", "w") as f:
        f.write("PMWD simulation completed.\n")
        f.write(f"Args : {args}\n")
        f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
        for i, time in enumerate(chrono_fun.times):
            f.write(f"Time {i}: {time:.4f} ms\n")
    if args.save_fields:
        np.save(f"{field_folder}/final_field_0_0.npy", final_field)
        print("Fields saved.")

    print(f"saving to {output_path}/pmwd.csv")
    print(f"saving field and logs to {field_folder}/pmwd.log")

View file

@@ -1,157 +0,0 @@
#!/bin/bash
##############################################################################################################################
# USAGE: sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm
##############################################################################################################################
#SBATCH --job-name=Particle-Mesh      # job name
#SBATCH --cpus-per-task=8             # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
#SBATCH --hint=nomultithread          # hyperthreading disabled
#SBATCH --time=04:00:00               # maximum requested runtime (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out       # output file name
#SBATCH --error=%x_%N_a100.err        # error file name (separate from the output here)
#SBATCH --exclusive                   # dedicated resources
##SBATCH --qos=qos_gpu-dev
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the launched commands
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in the current versions of Nsight Systems,
# a symbolic link must also be created so that the /tmp/nvidia
# directory points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
declare -A pdims_table
# Define the table
pdims_table[1]="1x1"
pdims_table[4]="2x2 1x4 4x1"
pdims_table[8]="2x4 1x8 8x1 4x2"
pdims_table[16]="4x4 1x16 16x1"
pdims_table[32]="4x8 8x4 1x32 32x1"
pdims_table[64]="8x8 16x4 1x64 64x1"
pdims_table[128]="8x16 16x8 1x128 128x1"
pdims_table[256]="16x16 1x256 256x1"
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096 8192)
precisions=(float32 float64)
pdim="${pdims_table[$nb_gpus]}"
solvers=(lpt lfm)
echo "pdims: $pdim"
# Check if pdims is not empty
if [ -z "$pdim" ]; then
echo "pdims is empty"
echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
echo "Number of nodes selected: $num_nodes"
echo "Number of gpus per node: $num_gpu_per_node"
exit 1
fi
# gpu_name was set above from the job partition (a100 or v100)
out_dir="pm_prof/$gpu_name/$nb_gpus"
trace_dir="traces/$gpu_name/$nb_gpus/bench_pm"
echo "Output dir is : $out_dir"
echo "Trace dir is : $trace_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
for p in $pdim; do
halo_size=$((g / 4))
slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
done
done
# delete crash core dump files
rm -f core.python.*
done
done

View file

@@ -1,147 +0,0 @@
#!/bin/bash
##############################################################################################################################
# USAGE: sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm
##############################################################################################################################
#SBATCH --job-name=Particle-Mesh      # job name
#SBATCH --cpus-per-task=8             # number of CPUs per task for gpu_p5 (1/8 of the 8-GPU node)
#SBATCH --hint=nomultithread          # hyperthreading disabled
#SBATCH --time=04:00:00               # maximum requested runtime (HH:MM:SS)
#SBATCH --output=%x_%N_a100.out       # output file name
#SBATCH --error=%x_%N_a100.out        # error file name (here shared with the output)
#SBATCH --exclusive                   # dedicated resources
##SBATCH --qos=qos_gpu-dev
# Clean up modules loaded interactively and inherited by default
num_nodes=$SLURM_JOB_NUM_NODES
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
OUTPUT_FOLDER_ARGS=1
# Calculate the number of GPUs
nb_gpus=$(( num_nodes * num_gpu_per_node))
module purge
echo "Job partition: $SLURM_JOB_PARTITION"
# Uncomment the following module command if you are using the "gpu_p5" partition
# to get access to the modules compatible with that partition
if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
module load cpuarch/amd
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
gpu_name=a100
else
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
gpu_name=v100
fi
# Load the modules
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
echo "Usage: profile_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
mkdir -p "$report_dir"
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
function run_python() {
if [ $# -lt 1 ]; then
echo "Usage: run_python <python_script> [arguments for the script]"
return 1
fi
local script_name=$(basename "$1" .py)
local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
local args=$(echo "${@:2}" | tr ' ' '_')
# Remove characters '/' and '-' from folder name
args=$(echo "$args" | tr -d '/-')
output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
fi
mkdir -p "$output_dir"
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
}
# run or profile
function slaunch() {
run_python "$@"
}
function plaunch() {
profile_python "$@"
}
# Echo the launched commands
set -x
# Do not use /tmp
export TMPDIR=$JOBSCRATCH
# To work around a bug in the current versions of Nsight Systems,
# a symbolic link must also be created so that the /tmp/nvidia
# directory points to TMPDIR
ln -s $JOBSCRATCH /tmp/nvidia
# mpch=(128 256 512 1024 2048 4096)
grid=(256 512 1024 2048 4096 8192)
precisions=(float32 float64)
solvers=(lpt lfm)
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
if [ $num_gpu_per_node -eq 8 ]; then
gpu_name="a100"
else
gpu_name="v100"
fi
out_dir="pm_prof/$gpu_name/$nb_gpus"
trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd"
echo "Output dir is : $out_dir"
echo "Trace dir is : $trace_dir"
for pr in "${precisions[@]}"; do
for g in "${grid[@]}"; do
for solver in "${solvers[@]}"; do
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
done
# delete crash core dump files
rm -f core.python.*
done
done
# # zip the output files and traces
# tar -czvf $out_dir.tar.gz $out_dir
# tar -czvf $trace_dir.tar.gz $trace_dir
# # remove the output files and traces
# rm -rf $out_dir $trace_dir
#

View file

@@ -1,19 +0,0 @@
#!/bin/bash
# Run all slurms jobs
nodes_v100=(1 2 4 8 16 32)
nodes_a100=(1 2 4 8 16 32)
for n in ${nodes_v100[@]}; do
sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm
done
for n in ${nodes_a100[@]}; do
sbatch --account=tkc@a100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C a100 --job-name=JAXPM-$n-N-a100 particle_mesh.slurm
done
# single GPUs
sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=JAXPM-1GPU-a100 particle_mesh.slurm
sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-v100 particle_mesh.slurm
sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=PMWD-1GPU-a100 pmwd_pm.slurm
sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-v100 pmwd_pm.slurm

View file

@@ -10,13 +10,13 @@ import jax.numpy as jnp
 import jaxdecomp
 from jax import lax
 from jax.experimental.shard_map import shard_map
-from jax.sharding import Mesh
+from jax.sharding import AbstractMesh, Mesh
 from jax.sharding import PartitionSpec as P
 
 
 def autoshmap(
     f: Callable,
-    gpu_mesh: Mesh | None,
+    gpu_mesh: Mesh | AbstractMesh | None,
     in_specs: Specs,
     out_specs: Specs,
     check_rep: bool = False,
@@ -122,7 +122,7 @@ def get_local_shape(mesh_shape, sharding=None):
     ]
 
-def __axis_names(spec):
+def _axis_names(spec):
     if len(spec) == 1:
         x_axis, = spec
         y_axis = None
@@ -147,7 +147,7 @@ def uniform_particles(mesh_shape, sharding=None):
     if gpu_mesh is not None and not (gpu_mesh.empty):
         local_mesh_shape = get_local_shape(mesh_shape, sharding)
         spec = sharding.spec
-        x_axis, y_axis, single_axis = __axis_names(spec)
+        x_axis, y_axis, single_axis = _axis_names(spec)
 
         def particles():
             x_indx = lax.axis_index(x_axis)
@@ -178,7 +178,7 @@ def normal_field(mesh_shape, seed, sharding=None):
         # to make the code work both in multi host and single controller we can do this trick
         keys = jax.random.split(seed, size)
         spec = sharding.spec
-        x_axis, y_axis, single_axis = __axis_names(spec)
+        x_axis, y_axis, single_axis = _axis_names(spec)
 
         def normal(keys, shape, dtype):
             idx = lax.axis_index(x_axis)
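Note: the body of `autoshmap` is not shown in this hunk, so the following is only a plausible reading of what the widened `Mesh | AbstractMesh | None` annotation means for callers, not the repository's implementation.

from jax.experimental.shard_map import shard_map

def autoshmap(f, gpu_mesh, in_specs, out_specs, check_rep=False):
    # Assumption: with no mesh (or an empty one) the function runs as-is on a
    # single device; otherwise it is wrapped with shard_map over the given
    # concrete Mesh or AbstractMesh.
    if gpu_mesh is None or gpu_mesh.empty:
        return f
    return shard_map(f, mesh=gpu_mesh, in_specs=in_specs,
                     out_specs=out_specs, check_rep=check_rep)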

View file

@@ -3,6 +3,7 @@ from functools import partial
 import jax
 import jax.lax as lax
 import jax.numpy as jnp
+from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P
 
 from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
@@ -11,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
 from jaxpm.painting_utils import gather, scatter
 
 
-def cic_paint_impl(grid_mesh, positions, weight=None):
+def _cic_paint_impl(grid_mesh, positions, weight=None):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     displacement field: [nx, ny, nz, 3]
@@ -54,9 +55,9 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
     halo_size, halo_extents = get_halo_size(halo_size, sharding)
     grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
 
-    gpu_mesh = sharding.mesh if sharding is not None else None
-    spec = sharding.spec if sharding is not None else P()
-    grid_mesh = autoshmap(cic_paint_impl,
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    grid_mesh = autoshmap(_cic_paint_impl,
                           gpu_mesh=gpu_mesh,
                           in_specs=(spec, spec, P()),
                           out_specs=spec)(grid_mesh, positions, weight)
@@ -68,7 +69,7 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
     return grid_mesh
 
 
-def cic_read_impl(grid_mesh, positions):
+def _cic_read_impl(grid_mesh, positions):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     positions: [nx,ny,nz, 3]
@@ -110,10 +111,10 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
                               halo_periods=(True, True))
 
-    gpu_mesh = sharding.mesh if sharding is not None else None
-    spec = sharding.spec if sharding is not None else P()
-    displacement = autoshmap(cic_read_impl,
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    displacement = autoshmap(_cic_read_impl,
                              gpu_mesh=gpu_mesh,
                              in_specs=(spec, spec),
                              out_specs=spec)(grid_mesh, positions)
@@ -150,7 +151,7 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh
 
 
-def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
+def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
@@ -187,9 +188,9 @@ def cic_paint_dx(displacements,
     halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
 
-    gpu_mesh = sharding.mesh if sharding is not None else None
-    spec = sharding.spec if sharding is not None else P()
-    grid_mesh = autoshmap(partial(cic_paint_dx_impl,
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
                                   halo_size=halo_size,
                                   weight=weight,
                                   chunk_size=chunk_size),
@@ -204,7 +205,7 @@ def cic_paint_dx(displacements,
     return grid_mesh
 
 
-def cic_read_dx_impl(grid_mesh, disp, halo_size):
+def _cic_read_dx_impl(grid_mesh, disp, halo_size):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
@@ -233,9 +234,9 @@ def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
                               halo_periods=(True, True))
-    gpu_mesh = sharding.mesh if sharding is not None else None
-    spec = sharding.spec if sharding is not None else P()
-    displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
+    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
+    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
                               gpu_mesh=gpu_mesh,
                               in_specs=(spec),
                               out_specs=spec)(grid_mesh, disp)
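Note: the switch from `sharding is not None` to `isinstance(sharding, NamedSharding)` matters because only a `NamedSharding` carries a device mesh and a `PartitionSpec`; any other sharding (or a default single-device one) has no `.mesh`/`.spec` and would otherwise raise an `AttributeError`. A minimal sketch of the guard, with a helper name that is ours rather than the repository's:

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

def _mesh_and_spec(sharding):
    # Only NamedSharding exposes .mesh and .spec; anything else (including None)
    # falls back to single-device behaviour with an empty PartitionSpec.
    if isinstance(sharding, NamedSharding):
        return sharding.mesh, sharding.spec
    return None, P()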

View file

@@ -5,8 +5,8 @@ initialize_distributed()  # ignore : E402
 import jax  # noqa : E402
 import jax.numpy as jnp  # noqa : E402
 import pytest  # noqa : E402
-from diffrax import (Dopri5, ODETerm, PIDController, SaveAt,  # noqa : E402
-                     diffeqsolve)
+from diffrax import SaveAt  # noqa : E402
+from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
 from helpers import MSE  # noqa : E402
 from jax import lax  # noqa : E402
 from jax.experimental.multihost_utils import process_allgather  # noqa : E402