mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
Update for jaxDecomp pure JAX
This commit is contained in:
parent
831291c1f9
commit
2ea05a1cd6
9 changed files with 214 additions and 532 deletions
|
@ -25,7 +25,7 @@ from jax.sharding import PartitionSpec as P
|
|||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
from jax import make_jaxpr
|
||||
|
||||
|
||||
def run_simulation(mesh_shape,
|
||||
|
@ -33,7 +33,10 @@ def run_simulation(mesh_shape,
|
|||
halo_size,
|
||||
solver_choice,
|
||||
iterations,
|
||||
pdims=None):
|
||||
hlo_print,
|
||||
trace,
|
||||
pdims=None,
|
||||
output_path="."):
|
||||
|
||||
@jax.jit
|
||||
def simulate(omega_c, sigma8):
|
||||
|
@ -60,7 +63,6 @@ def run_simulation(mesh_shape,
|
|||
solver = Tsit5()
|
||||
elif solver_choice == "lpt":
|
||||
lpt_field = cic_paint_dx(dx, halo_size=halo_size)
|
||||
print(f"TYPE of lpt_field: {type(lpt_field)}")
|
||||
return lpt_field, {"num_steps": 0}
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -92,16 +94,37 @@ def run_simulation(mesh_shape,
|
|||
|
||||
def run():
|
||||
# Warm start
|
||||
chrono_fun = Timer()
|
||||
RangePush("warmup")
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
RangePop()
|
||||
sync_global_devices("warmup")
|
||||
for i in range(iterations):
|
||||
RangePush(f"sim iter {i}")
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
RangePop()
|
||||
return final_field, stats, chrono_fun
|
||||
if hlo_print:
|
||||
jaxpr = make_jaxpr(simulate)(0.32, 0.8)
|
||||
lowered = jax.jit(simulate).lower(0.32, 0.8)
|
||||
lower_as_text = lowered.as_text()
|
||||
compiled = lowered.compile()
|
||||
compiled_again = jax.jit(simulate).lower(0.32, 0.8).compile()
|
||||
return jaxpr , compiled , compiled_again
|
||||
elif trace:
|
||||
jit_output = f"{output_path}/jit_trace"
|
||||
first_run_output = f"{output_path}/first_run_trace"
|
||||
second_run_output = f"{output_path}/second_run_trace"
|
||||
with jax.profiler.trace(jit_output , create_perfetto_trace=True):
|
||||
final_field, stats = simulate(0.32, 0.8)
|
||||
final_field.block_until_ready()
|
||||
with jax.profiler.trace(first_run_output , create_perfetto_trace=True):
|
||||
final_field, stats = simulate(0.32, 0.8)
|
||||
final_field.block_until_ready()
|
||||
with jax.profiler.trace(second_run_output , create_perfetto_trace=True):
|
||||
final_field, stats = simulate(0.32, 0.8)
|
||||
final_field.block_until_ready()
|
||||
else:
|
||||
chrono_fun = Timer()
|
||||
RangePush("warmup")
|
||||
final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
RangePop()
|
||||
sync_global_devices("warmup")
|
||||
for i in range(iterations):
|
||||
RangePush(f"sim iter {i}")
|
||||
final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
|
||||
RangePop()
|
||||
return final_field, stats, chrono_fun
|
||||
|
||||
if jax.device_count() > 1:
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
|
@ -143,7 +166,7 @@ if __name__ == "__main__":
|
|||
'--halo_size',
|
||||
type=int,
|
||||
help='Halo size',
|
||||
required=True)
|
||||
default=None)
|
||||
parser.add_argument('-s',
|
||||
'--solver',
|
||||
type=str,
|
||||
|
@ -153,12 +176,7 @@ if __name__ == "__main__":
|
|||
"LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
|
||||
"lpt"
|
||||
],
|
||||
required=True)
|
||||
parser.add_argument('-i',
|
||||
'--iterations',
|
||||
type=int,
|
||||
help='Number of iterations',
|
||||
default=10)
|
||||
default="lpt")
|
||||
parser.add_argument('-o',
|
||||
'--output_path',
|
||||
type=str,
|
||||
|
@ -173,16 +191,33 @@ if __name__ == "__main__":
|
|||
type=int,
|
||||
help='Number of nodes',
|
||||
default=1)
|
||||
parser.add_argument('-i',
|
||||
'--iterations',
|
||||
type=int,
|
||||
help='Number of iterations',
|
||||
default=10)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument('-hlo',
|
||||
'--hlo_print',
|
||||
action='store_true',
|
||||
help='Print hlo generated by XLA')
|
||||
group.add_argument('-t',
|
||||
'--trace',
|
||||
action='store_true',
|
||||
help='Profile using tensorboard')
|
||||
|
||||
args = parser.parse_args()
|
||||
mesh_size = args.mesh_size
|
||||
box_size = [args.box_size] * 3
|
||||
halo_size = args.halo_size
|
||||
halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size
|
||||
solver_choice = args.solver
|
||||
iterations = args.iterations
|
||||
output_path = args.output_path
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
hlo_print = args.hlo_print
|
||||
trace = args.trace
|
||||
nb_gpus = jax.device_count()
|
||||
|
||||
print(f"solver choice: {solver_choice}")
|
||||
match solver_choice:
|
||||
case "Dopri5" | "dopri5"| "d5":
|
||||
|
@ -205,45 +240,81 @@ if __name__ == "__main__":
|
|||
if args.pdims:
|
||||
pdims = tuple(map(int, args.pdims.split("x")))
|
||||
else:
|
||||
pdims = (1, 1)
|
||||
pdims = (1, jax.device_count())
|
||||
pdm_str = f"{pdims[0]}x{pdims[1]}"
|
||||
|
||||
mesh_shape = [mesh_size] * 3
|
||||
|
||||
final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
|
||||
|
||||
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
|
||||
if trace:
|
||||
trace_folder = f"{output_path}/profiling/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
|
||||
os.makedirs(trace_folder, exist_ok=True)
|
||||
run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, trace_folder)
|
||||
print(f"Profiling done! Check {trace_folder}")
|
||||
elif hlo_print:
|
||||
hlo_folder = f"{output_path}/hlo/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
|
||||
os.makedirs(hlo_folder, exist_ok=True)
|
||||
jaxpr , compiled , compiled2 = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, hlo_folder)
|
||||
print(f"type of memory analysis {type(compiled.memory_analysis())}")
|
||||
print(f"memory analysis {compiled.memory_analysis()}")
|
||||
print(f"memory analysis again {compiled2.memory_analysis()}")
|
||||
jax.tree.map(lambda x: print(x), compiled.memory_analysis())
|
||||
with open(f'{hlo_folder}/hlo_jaxpm.md', 'w') as f:
|
||||
f.write(f"# JAXPM HLO\n")
|
||||
f.write(f"## Args: {args}\n")
|
||||
f.write(f"## JAXPR is \n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"{jaxpr}\n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"Lowered as text is \n")
|
||||
f.write(f'---\n')
|
||||
# f.write(f"{lower_as_text}\n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"Compiled is \n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"{compiled.as_text()}\n")
|
||||
f.write(f"Cost analysis is \n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"{compiled.cost_analysis()[0]['flops']}\n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"Memory analysis is \n")
|
||||
f.write(f'---\n')
|
||||
f.write(f"{compiled.memory_analysis()}\n")
|
||||
f.write(f'---\n')
|
||||
|
||||
metadata = {
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
# Print the results to a CSV file
|
||||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
|
||||
print(f"Saved HLO to {hlo_folder}")
|
||||
else:
|
||||
final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, output_path)
|
||||
print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
|
||||
metadata = {
|
||||
'rank': rank,
|
||||
'function_name': f'JAXPM-{solver_choice}',
|
||||
'precision': args.precision,
|
||||
'x': str(mesh_size),
|
||||
'y': str(mesh_size),
|
||||
'z': str(stats["num_steps"]),
|
||||
'px': str(pdims[0]),
|
||||
'py': str(pdims[1]),
|
||||
'backend': 'NCCL',
|
||||
'nodes': str(args.nodes)
|
||||
}
|
||||
# Print the results to a CSV file
|
||||
chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
|
||||
|
||||
# Save the final field
|
||||
nb_gpus = jax.device_count()
|
||||
pdm_str = f"{pdims[0]}x{pdims[1]}"
|
||||
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
|
||||
os.makedirs(field_folder, exist_ok=True)
|
||||
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||
f.write(f"Args: {args}\n")
|
||||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||
for i , time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
f.write(f"Stats: {stats}\n")
|
||||
if args.save_fields:
|
||||
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||
final_field.addressable_data(0))
|
||||
# Save the final field
|
||||
|
||||
print(f"Finished! ")
|
||||
print(f"Stats {stats}")
|
||||
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
|
||||
print(f"Saving field and logs in {field_folder}")
|
||||
field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
|
||||
os.makedirs(field_folder, exist_ok=True)
|
||||
with open(f'{field_folder}/jaxpm.log', 'w') as f:
|
||||
f.write(f"Args: {args}\n")
|
||||
f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
|
||||
for i , time in enumerate(chrono_fun.times):
|
||||
f.write(f"Time {i}: {time:.4f} ms\n")
|
||||
f.write(f"Stats: {stats}\n")
|
||||
if args.save_fields:
|
||||
np.save(f'{field_folder}/final_field_0_{rank}.npy',
|
||||
final_field.addressable_data(0))
|
||||
|
||||
print(f"Finished! ")
|
||||
print(f"Stats {stats}")
|
||||
print(f"Saving to {output_path}/jax_pm_benchmark.csv")
|
||||
print(f"Saving field and logs in {field_folder}")
|
||||
|
|
|
@ -1,40 +1,15 @@
|
|||
#!/bin/bash
|
||||
##########################################
|
||||
## SELECT EITHER tkc@a100 OR tkc@v100 ##
|
||||
##########################################
|
||||
#SBATCH --account tkc@a100
|
||||
##########################################
|
||||
#SBATCH --job-name=1N-FFT-Mesh # nom du job
|
||||
# Il est possible d'utiliser une autre partition que celle par default
|
||||
# en activant l'une des 5 directives suivantes :
|
||||
##########################################
|
||||
## SELECT EITHER a100 or v100-32g ##
|
||||
##########################################
|
||||
#SBATCH -C a100
|
||||
##########################################
|
||||
#******************************************
|
||||
##########################################
|
||||
## SELECT Number of nodes and GPUs per node
|
||||
## For A100 ntasks-per-node and gres=gpu should be 8
|
||||
## For V100 ntasks-per-node and gres=gpu should be 4
|
||||
##########################################
|
||||
#SBATCH --nodes=1 # nombre de noeud
|
||||
#SBATCH --ntasks-per-node=8 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
|
||||
#SBATCH --gres=gpu:8 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
|
||||
##########################################
|
||||
## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
|
||||
## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
|
||||
## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
|
||||
##########################################
|
||||
##############################################################################################################################
|
||||
# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm
|
||||
##############################################################################################################################
|
||||
#SBATCH --job-name=Particle-Mesh # nom du job
|
||||
#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
|
||||
##########################################
|
||||
# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
|
||||
#SBATCH --hint=nomultithread # hyperthreading desactive
|
||||
#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||
#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
|
||||
#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie)
|
||||
#SBATCH --error=%x_%N_a100.err # nom du fichier d'erreur (ici commun avec la sortie)
|
||||
#SBATCH --exclusive # ressources dediees
|
||||
##SBATCH --qos=qos_gpu-dev
|
||||
## SBATCH --exclusive # ressources dediees
|
||||
# Nettoyage des modules charges en interactif et herites par defaut
|
||||
num_nodes=$SLURM_JOB_NUM_NODES
|
||||
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
|
||||
|
@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node))
|
|||
|
||||
module purge
|
||||
|
||||
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
|
||||
echo "Job partition: $SLURM_JOB_PARTITION"
|
||||
# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
|
||||
# pour avoir acces aux modules compatibles avec cette partition
|
||||
|
||||
if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
|
||||
if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
|
||||
module load cpuarch/amd
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
|
||||
gpu_name=a100
|
||||
|
@ -66,8 +40,10 @@ module load nvidia-nsight-systems/2024.1.1.59
|
|||
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: profile_python <python_script> [arguments for the script]"
|
||||
|
@ -75,14 +51,14 @@ function profile_python() {
|
|||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="prof_traces/$script_name"
|
||||
local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
|
||||
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="prof_traces/$script_name/$args"
|
||||
output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
|
@ -99,13 +75,13 @@ function run_python() {
|
|||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="traces/$script_name"
|
||||
local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="traces/$script_name/$args"
|
||||
output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
@ -142,12 +118,11 @@ pdims_table[8]="2x4 1x8 8x1 4x2"
|
|||
pdims_table[16]="4x4 1x16 16x1"
|
||||
pdims_table[32]="4x8 8x4 1x32 32x1"
|
||||
pdims_table[64]="8x8 16x4 1x64 64x1"
|
||||
pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2"
|
||||
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
|
||||
|
||||
pdims_table[128]="8x16 16x8 1x128 128x1"
|
||||
pdims_table[256]="16x16 1x256 256x1"
|
||||
|
||||
# mpch=(128 256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096 8192)
|
||||
precisions=(float32 float64)
|
||||
pdim="${pdims_table[$nb_gpus]}"
|
||||
solvers=(lpt lfm)
|
||||
|
@ -164,8 +139,9 @@ fi
|
|||
|
||||
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
|
||||
out_dir="pm_prof/$gpu_name/$nb_gpus"
|
||||
|
||||
echo "Output dir is : $out_dir"
|
||||
trace_dir="traces/$gpu_name/$nb_gpus/bench_pm"
|
||||
echo "Output dir is : $out_dir"
|
||||
echo "Trace dir is : $trace_dir"
|
||||
|
||||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
|
@ -175,9 +151,13 @@ for pr in "${precisions[@]}"; do
|
|||
slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
|
||||
done
|
||||
done
|
||||
# delete crash core dump files
|
||||
rm -f core.python.*
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
|
||||
# # zip the output files and traces
|
||||
# tar -czvf $out_dir.tar.gz $out_dir
|
||||
# tar -czvf $trace_dir.tar.gz $trace_dir
|
||||
# # remove the output files and traces
|
||||
# rm -rf $out_dir $trace_dir
|
|
@ -1,184 +0,0 @@
|
|||
#!/bin/bash
|
||||
##########################################
|
||||
## SELECT EITHER tkc@a100 OR tkc@v100 ##
|
||||
##########################################
|
||||
#SBATCH --account tkc@v100
|
||||
##########################################
|
||||
#SBATCH --job-name=V100Particle-Mesh # nom du job
|
||||
# Il est possible d'utiliser une autre partition que celle par default
|
||||
# en activant l'une des 5 directives suivantes :
|
||||
##########################################
|
||||
## SELECT EITHER a100 or v100-32g ##
|
||||
##########################################
|
||||
#SBATCH -C v100-32g
|
||||
##########################################
|
||||
#******************************************
|
||||
##########################################
|
||||
## SELECT Number of nodes and GPUs per node
|
||||
## For A100 ntasks-per-node and gres=gpu should be 8
|
||||
## For V100 ntasks-per-node and gres=gpu should be 4
|
||||
##########################################
|
||||
#SBATCH --nodes=1 # nombre de noeud
|
||||
#SBATCH --ntasks-per-node=4 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
|
||||
#SBATCH --gres=gpu:4 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
|
||||
##########################################
|
||||
## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
|
||||
## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
|
||||
## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
|
||||
##########################################
|
||||
#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
|
||||
##########################################
|
||||
# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
|
||||
#SBATCH --hint=nomultithread # hyperthreading desactive
|
||||
#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||
#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
|
||||
#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie)
|
||||
#SBATCH --qos=qos_gpu-dev
|
||||
#SBATCH --exclusive # ressources dediees
|
||||
# Nettoyage des modules charges en interactif et herites par defaut
|
||||
num_nodes=$SLURM_JOB_NUM_NODES
|
||||
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
|
||||
OUTPUT_FOLDER_ARGS=1
|
||||
# Calculate the number of GPUs
|
||||
nb_gpus=$(( num_nodes * num_gpu_per_node))
|
||||
|
||||
module purge
|
||||
|
||||
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
|
||||
echo "Job partition: $SLURM_JOB_PARTITION"
|
||||
# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
|
||||
# pour avoir acces aux modules compatibles avec cette partition
|
||||
|
||||
if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
|
||||
module load cpuarch/amd
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
|
||||
gpu_name=a100
|
||||
else
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
|
||||
gpu_name=v100
|
||||
fi
|
||||
|
||||
# Chargement des modules
|
||||
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
|
||||
module load nvidia-nsight-systems/2024.1.1.59
|
||||
|
||||
|
||||
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export CUDA_ALLOC=1
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: profile_python <python_script> [arguments for the script]"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="prof_traces/$script_name"
|
||||
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="prof_traces/$script_name/$args"
|
||||
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
mkdir -p "$report_dir"
|
||||
|
||||
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||
}
|
||||
|
||||
function run_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: run_python <python_script> [arguments for the script]"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="traces/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="traces/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
||||
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||
}
|
||||
|
||||
|
||||
# run or profile
|
||||
|
||||
function slaunch() {
|
||||
run_python "$@"
|
||||
}
|
||||
|
||||
function plaunch() {
|
||||
profile_python "$@"
|
||||
}
|
||||
|
||||
# Echo des commandes lancees
|
||||
set -x
|
||||
|
||||
# Pour ne pas utiliser le /tmp
|
||||
export TMPDIR=$JOBSCRATCH
|
||||
# Pour contourner un bogue dans les versions actuelles de Nsight Systems
|
||||
# il est également nécessaire de créer un lien symbolique permettant de
|
||||
# faire pointer le répertoire /tmp/nvidia vers TMPDIR
|
||||
ln -s $JOBSCRATCH /tmp/nvidia
|
||||
|
||||
declare -A pdims_table
|
||||
# Define the table
|
||||
pdims_table[1]="1x1"
|
||||
pdims_table[4]="2x2 1x4 4x1"
|
||||
pdims_table[8]="2x4 1x8 8x1 4x2"
|
||||
pdims_table[16]="4x4 1x16 16x1"
|
||||
pdims_table[32]="4x8 8x4 1x32 32x1"
|
||||
pdims_table[64]="8x8 16x4 1x64 64x1"
|
||||
pdims_table[128]="8x16 16x8 4x32 1x128 128x1"
|
||||
pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
|
||||
|
||||
|
||||
# mpch=(128 256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096)
|
||||
precisions=(float32 float64)
|
||||
pdim="${pdims_table[$nb_gpus]}"
|
||||
solvers=(lpt lfm)
|
||||
echo "pdims: $pdim"
|
||||
|
||||
# Check if pdims is not empty
|
||||
if [ -z "$pdim" ]; then
|
||||
echo "pdims is empty"
|
||||
echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
|
||||
echo "Number of nodes selected: $num_nodes"
|
||||
echo "Number of gpus per node: $num_gpu_per_node"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
|
||||
out_dir="pm_prof/$gpu_name/$nb_gpus"
|
||||
|
||||
echo "Output dir is : $out_dir"
|
||||
|
||||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
for solver in "${solvers[@]}"; do
|
||||
for p in $pdim; do
|
||||
halo_size=$((g / 4))
|
||||
slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
|
@ -1,165 +0,0 @@
|
|||
#!/bin/bash
|
||||
##########################################
|
||||
## SELECT EITHER tkc@a100 OR tkc@v100 ##
|
||||
##########################################
|
||||
#SBATCH --account tkc@a100
|
||||
##########################################
|
||||
#SBATCH --job-name=1N-FFT-Mesh # nom du job
|
||||
# Il est possible d'utiliser une autre partition que celle par default
|
||||
# en activant l'une des 5 directives suivantes :
|
||||
##########################################
|
||||
## SELECT EITHER a100 or v100-32g ##
|
||||
##########################################
|
||||
#SBATCH -C a100
|
||||
##########################################
|
||||
#******************************************
|
||||
##########################################
|
||||
## SELECT Number of nodes and GPUs per node
|
||||
## For A100 ntasks-per-node and gres=gpu should be 8
|
||||
## For V100 ntasks-per-node and gres=gpu should be 4
|
||||
##########################################
|
||||
#SBATCH --nodes=1 # nombre de noeud
|
||||
#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
|
||||
#SBATCH --gres=gpu:1 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
|
||||
##########################################
|
||||
## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
|
||||
## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
|
||||
## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
|
||||
##########################################
|
||||
#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
|
||||
##########################################
|
||||
# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
|
||||
#SBATCH --hint=nomultithread # hyperthreading desactive
|
||||
#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||
#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
|
||||
#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie)
|
||||
##SBATCH --qos=qos_gpu-dev
|
||||
## SBATCH --exclusive # ressources dediees
|
||||
# Nettoyage des modules charges en interactif et herites par defaut
|
||||
num_nodes=$SLURM_JOB_NUM_NODES
|
||||
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
|
||||
OUTPUT_FOLDER_ARGS=1
|
||||
# Calculate the number of GPUs
|
||||
nb_gpus=$(( num_nodes * num_gpu_per_node))
|
||||
|
||||
module purge
|
||||
|
||||
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
|
||||
echo "Job partition: $SLURM_JOB_PARTITION"
|
||||
# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
|
||||
# pour avoir acces aux modules compatibles avec cette partition
|
||||
|
||||
if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
|
||||
module load cpuarch/amd
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
|
||||
gpu_name=a100
|
||||
else
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
|
||||
gpu_name=v100
|
||||
fi
|
||||
|
||||
# Chargement des modules
|
||||
module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
|
||||
module load nvidia-nsight-systems/2024.1.1.59
|
||||
|
||||
|
||||
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export CUDA_ALLOC=1
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: profile_python <python_script> [arguments for the script]"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="prof_traces/$script_name"
|
||||
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="prof_traces/$script_name/$args"
|
||||
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
mkdir -p "$report_dir"
|
||||
|
||||
srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||
}
|
||||
|
||||
function run_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: run_python <python_script> [arguments for the script]"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="traces/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="traces/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
||||
srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
|
||||
}
|
||||
|
||||
|
||||
# run or profile
|
||||
|
||||
function slaunch() {
|
||||
run_python "$@"
|
||||
}
|
||||
|
||||
function plaunch() {
|
||||
profile_python "$@"
|
||||
}
|
||||
|
||||
# Echo des commandes lancees
|
||||
set -x
|
||||
|
||||
# Pour ne pas utiliser le /tmp
|
||||
export TMPDIR=$JOBSCRATCH
|
||||
# Pour contourner un bogue dans les versions actuelles de Nsight Systems
|
||||
# il est également nécessaire de créer un lien symbolique permettant de
|
||||
# faire pointer le répertoire /tmp/nvidia vers TMPDIR
|
||||
ln -s $JOBSCRATCH /tmp/nvidia
|
||||
|
||||
# mpch=(128 256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096)
|
||||
precisions=(float32 float64)
|
||||
solvers=(lpt lfm)
|
||||
|
||||
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
|
||||
|
||||
if [ $num_gpu_per_node -eq 8 ]; then
|
||||
gpu_name="a100"
|
||||
else
|
||||
gpu_name="v100"
|
||||
fi
|
||||
|
||||
out_dir="pm_prof/$gpu_name/$nb_gpus"
|
||||
|
||||
echo "Output dir is : $out_dir"
|
||||
|
||||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
for solver in "${solvers[@]}"; do
|
||||
launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
|
@ -1,40 +1,15 @@
|
|||
#!/bin/bash
|
||||
##########################################
|
||||
## SELECT EITHER tkc@a100 OR tkc@v100 ##
|
||||
##########################################
|
||||
#SBATCH --account tkc@v100
|
||||
##########################################
|
||||
#SBATCH --job-name=16N-V100Particle-Mesh # nom du job
|
||||
# Il est possible d'utiliser une autre partition que celle par default
|
||||
# en activant l'une des 5 directives suivantes :
|
||||
##########################################
|
||||
## SELECT EITHER a100 or v100-32g ##
|
||||
##########################################
|
||||
#SBATCH -C v100-32g
|
||||
##########################################
|
||||
#******************************************
|
||||
##########################################
|
||||
## SELECT Number of nodes and GPUs per node
|
||||
## For A100 ntasks-per-node and gres=gpu should be 8
|
||||
## For V100 ntasks-per-node and gres=gpu should be 4
|
||||
##########################################
|
||||
#SBATCH --nodes=1 # nombre de noeud
|
||||
#SBATCH --ntasks-per-node=1 # nombre de tache MPI par noeud (= nombre de GPU par noeud)
|
||||
#SBATCH --gres=gpu:1 # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
|
||||
##########################################
|
||||
## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
|
||||
## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
|
||||
## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
|
||||
##########################################
|
||||
##############################################################################################################################
|
||||
# USAGE:sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh_a100.slurm
|
||||
##############################################################################################################################
|
||||
#SBATCH --job-name=Particle-Mesh # nom du job
|
||||
#SBATCH --cpus-per-task=8 # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
|
||||
##########################################
|
||||
# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
|
||||
#SBATCH --hint=nomultithread # hyperthreading desactive
|
||||
#SBATCH --time=02:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||
#SBATCH --time=04:00:00 # temps d'execution maximum demande (HH:MM:SS)
|
||||
#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
|
||||
#SBATCH --error=%x_%N_a100.out # nom du fichier d'erreur (ici commun avec la sortie)
|
||||
#SBATCH --qos=qos_gpu-dev
|
||||
#SBATCH --exclusive # ressources dediees
|
||||
##SBATCH --qos=qos_gpu-dev
|
||||
# Nettoyage des modules charges en interactif et herites par defaut
|
||||
num_nodes=$SLURM_JOB_NUM_NODES
|
||||
num_gpu_per_node=$SLURM_NTASKS_PER_NODE
|
||||
|
@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node))
|
|||
|
||||
module purge
|
||||
|
||||
echo "Job constraint: $SLURM_JOB_CONSTRAINT"
|
||||
echo "Job partition: $SLURM_JOB_PARTITION"
|
||||
# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
|
||||
# pour avoir acces aux modules compatibles avec cette partition
|
||||
|
||||
if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
|
||||
if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
|
||||
module load cpuarch/amd
|
||||
source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
|
||||
gpu_name=a100
|
||||
|
@ -67,9 +41,9 @@ echo "The number of nodes allocated for this job is: $num_nodes"
|
|||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export CUDA_ALLOC=1
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: profile_python <python_script> [arguments for the script]"
|
||||
|
@ -77,14 +51,14 @@ function profile_python() {
|
|||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="prof_traces/$script_name"
|
||||
local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
|
||||
local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="prof_traces/$script_name/$args"
|
||||
output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
|
@ -101,13 +75,13 @@ function run_python() {
|
|||
fi
|
||||
|
||||
local script_name=$(basename "$1" .py)
|
||||
local output_dir="traces/$script_name"
|
||||
local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
|
||||
|
||||
if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
|
||||
local args=$(echo "${@:2}" | tr ' ' '_')
|
||||
# Remove characters '/' and '-' from folder name
|
||||
args=$(echo "$args" | tr -d '/-')
|
||||
output_dir="traces/$script_name/$args"
|
||||
output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
|
||||
fi
|
||||
|
||||
mkdir -p "$output_dir"
|
||||
|
@ -116,6 +90,7 @@ function run_python() {
|
|||
}
|
||||
|
||||
|
||||
|
||||
# run or profile
|
||||
|
||||
function slaunch() {
|
||||
|
@ -136,13 +111,9 @@ export TMPDIR=$JOBSCRATCH
|
|||
# faire pointer le répertoire /tmp/nvidia vers TMPDIR
|
||||
ln -s $JOBSCRATCH /tmp/nvidia
|
||||
|
||||
|
||||
|
||||
|
||||
# mpch=(128 256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096)
|
||||
grid=(256 512 1024 2048 4096 8192)
|
||||
precisions=(float32 float64)
|
||||
pdim="${pdims_table[$nb_gpus]}"
|
||||
solvers=(lpt lfm)
|
||||
|
||||
# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
|
||||
|
@ -154,17 +125,23 @@ else
|
|||
fi
|
||||
|
||||
out_dir="pm_prof/$gpu_name/$nb_gpus"
|
||||
|
||||
echo "Output dir is : $out_dir"
|
||||
|
||||
trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd"
|
||||
echo "Output dir is : $out_dir"
|
||||
echo "Trace dir is : $trace_dir"
|
||||
|
||||
for pr in "${precisions[@]}"; do
|
||||
for g in "${grid[@]}"; do
|
||||
for solver in "${solvers[@]}"; do
|
||||
slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f
|
||||
done
|
||||
# delete crash core dump files
|
||||
rm -f core.python.*
|
||||
done
|
||||
done
|
||||
|
||||
|
||||
|
||||
# # zip the output files and traces
|
||||
# tar -czvf $out_dir.tar.gz $out_dir
|
||||
# tar -czvf $trace_dir.tar.gz $trace_dir
|
||||
# # remove the output files and traces
|
||||
# rm -rf $out_dir $trace_dir
|
||||
#
|
|
@ -1,19 +1,21 @@
|
|||
#!/bin/bash
|
||||
# Run all slurms jobs
|
||||
nodes_v100=(1 2 4 8 16)
|
||||
nodes_a100=(1 2 4 8 16)
|
||||
nodes_v100=(1 2 4 8 16 32)
|
||||
nodes_a100=(1 2 4 8 16 32)
|
||||
|
||||
|
||||
for n in ${nodes_v100[@]}; do
|
||||
sbatch --nodes=$n --job-name=v100_$n-JAXPM particle_mesh_v100.slurm
|
||||
sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm
|
||||
done
|
||||
|
||||
for n in ${nodes_a100[@]}; do
|
||||
sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm
|
||||
sbatch --account=tkc@a100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C a100 --job-name=JAXPM-$n-N-a100 particle_mesh.slurm
|
||||
done
|
||||
|
||||
# single GPUs
|
||||
sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_v100.slurm
|
||||
sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_a100.slurm
|
||||
sbatch --job-name=PMWD-v100 pmwd_v100.slurm
|
||||
sbatch --job-name=PMWD-a100 pmwd_a100.slurm
|
||||
sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=JAXPM-1GPU-V100 particle_mesh.slurm
|
||||
sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-A100 particle_mesh.slurm
|
||||
sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 --job-name=PMWD-1GPU-v100 pmwd_pm.slurm
|
||||
sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-a100 pmwd_pm.slurm
|
||||
|
||||
|
||||
|
|
|
@ -88,8 +88,8 @@ def get_halo_size(halo_size):
|
|||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||
or halo_extents[1] > 0):
|
||||
if distributed and not (mesh.empty) and (halo_extents > 0
|
||||
or halo_extents > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
|
|
@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None):
|
|||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_paint(mesh, positions, halo_size=0, weight=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = autoshmap(cic_paint_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))(mesh, positions, weight)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
halo_extents=halo_size // 2,
|
||||
halo_periods=True)
|
||||
mesh = slice_unpad(mesh, halo_padding)
|
||||
return mesh
|
||||
|
||||
|
||||
|
@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement):
|
|||
@partial(jax.jit, static_argnums=(2, ))
|
||||
def cic_read(mesh, displacement, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh, displacement)
|
||||
|
@ -160,16 +160,16 @@ def cic_paint_dx_impl(displacements, halo_size):
|
|||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_paint_dx(displacements, halo_size=0):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
|
||||
mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(displacements)
|
||||
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
mesh = slice_unpad(mesh, halo_size)
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
mesh = slice_unpad(mesh, halo_padding)
|
||||
return mesh
|
||||
|
||||
|
||||
|
@ -194,12 +194,12 @@ def cic_read_dx_impl(mesh , halo_size):
|
|||
@partial(jax.jit, static_argnums=(1, ))
|
||||
def cic_read_dx(mesh, halo_size=0):
|
||||
# return mesh
|
||||
halo_size, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_size)
|
||||
halo_padding, halo_extents = get_halo_size(halo_size)
|
||||
mesh = slice_pad(mesh, halo_padding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True, True))
|
||||
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size),
|
||||
halo_extents=halo_size//2,
|
||||
halo_periods=True)
|
||||
displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_padding),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(mesh)
|
||||
|
||||
|
|
|
@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59
|
|||
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export CUDA_ALLOC=1
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
|
@ -122,6 +122,7 @@ set -x
|
|||
|
||||
declare -A pdims_table
|
||||
# Define the table
|
||||
pdims_table[1]="1x1"
|
||||
pdims_table[4]="2x2 1x4"
|
||||
pdims_table[8]="2x4 1x8"
|
||||
pdims_table[16]="2x8 1x16"
|
||||
|
|
Loading…
Add table
Reference in a new issue