Update for jaxDecomp pure JAX

This commit is contained in:
Wassim KABALAN 2024-08-07 23:52:13 +02:00
parent 831291c1f9
commit 2ea05a1cd6
9 changed files with 214 additions and 532 deletions

View file

@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59
echo "The number of nodes allocated for this job is: $num_nodes"
echo "The number of GPUs allocated for this job is: $nb_gpus"
export EQX_ON_ERROR=nan
export CUDA_ALLOC=1
export ENABLE_PERFO_STEP=NVTX
export MPI4JAX_USE_CUDA_MPI=1
function profile_python() {
if [ $# -lt 1 ]; then
@ -122,6 +122,7 @@ set -x
declare -A pdims_table
# Define the table
pdims_table[1]="1x1"
pdims_table[4]="2x2 1x4"
pdims_table[8]="2x4 1x8"
pdims_table[16]="2x8 1x16"