mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 12:31:11 +00:00
Update for jaxDecomp pure JAX
This commit is contained in:
parent
831291c1f9
commit
2ea05a1cd6
9 changed files with 214 additions and 532 deletions
|
@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59
|
|||
echo "The number of nodes allocated for this job is: $num_nodes"
|
||||
echo "The number of GPUs allocated for this job is: $nb_gpus"
|
||||
|
||||
export EQX_ON_ERROR=nan
|
||||
export CUDA_ALLOC=1
|
||||
export ENABLE_PERFO_STEP=NVTX
|
||||
export MPI4JAX_USE_CUDA_MPI=1
|
||||
|
||||
function profile_python() {
|
||||
if [ $# -lt 1 ]; then
|
||||
|
@ -122,6 +122,7 @@ set -x
|
|||
|
||||
declare -A pdims_table
|
||||
# Define the table
|
||||
pdims_table[1]="1x1"
|
||||
pdims_table[4]="2x2 1x4"
|
||||
pdims_table[8]="2x4 1x8"
|
||||
pdims_table[16]="2x8 1x16"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue