From 2ea05a1cd6a451dec406b222b69928c5cf6b6956 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN <wassim@apc.in2p3.fr>
Date: Wed, 7 Aug 2024 23:52:13 +0200
Subject: [PATCH] Update for the pure-JAX jaxDecomp backend

---
 benchmarks/bench_pm.py                        | 187 ++++++++++++------
 ...le_mesh_a100.slurm => particle_mesh.slurm} |  72 +++----
 benchmarks/particle_mesh_v100.slurm           | 184 -----------------
 benchmarks/pmwd_a100.slurm                    | 165 ----------------
 benchmarks/{pmwd_v100.slurm => pmwd_pm.slurm} |  73 +++----
 benchmarks/run_all_jobs.sh                    |  18 +-
 jaxpm/distributed.py                          |   4 +-
 jaxpm/painting.py                             |  38 ++--
 scripts/particle_mesh.slurm                   |   5 +-
 9 files changed, 214 insertions(+), 532 deletions(-)
 rename benchmarks/{particle_mesh_a100.slurm => particle_mesh.slurm} (65%)
 delete mode 100644 benchmarks/particle_mesh_v100.slurm
 delete mode 100644 benchmarks/pmwd_a100.slurm
 rename benchmarks/{pmwd_v100.slurm => pmwd_pm.slurm} (64%)
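
Note on the new --hlo_print mode in benchmarks/bench_pm.py: it relies on JAX's
ahead-of-time inspection API (jax.make_jaxpr, jax.jit(...).lower(...),
Lowered.compile(), and the compiled object's cost_analysis()/memory_analysis()).
A minimal standalone sketch of that pattern follows; the toy step() function,
its arguments, and the truncated prints are illustrative stand-ins, not code
taken from the benchmark.

# Sketch (Python/JAX) of the AOT inspection pattern behind --hlo_print.
# `step` is a toy stand-in for the benchmark's simulate(omega_c, sigma8).
import jax
import jax.numpy as jnp

def step(omega_c, sigma8):
    x = jnp.linspace(0.0, 1.0, 64)
    return jnp.sum(jnp.sin(omega_c * x)) * sigma8

jaxpr = jax.make_jaxpr(step)(0.32, 0.8)   # traced program, printable
lowered = jax.jit(step).lower(0.32, 0.8)  # StableHLO, before XLA compilation
compiled = lowered.compile()              # XLA executable

print(jaxpr)
print(lowered.as_text()[:300])            # lowered (StableHLO) text, truncated
print(compiled.as_text()[:300])           # optimized HLO text, truncated
print(compiled.cost_analysis())           # includes a 'flops' estimate
print(compiled.memory_analysis())         # backend memory summary (may be None)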

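Note on the new --trace mode: it wraps three runs (JIT/warmup, first and second
post-compilation runs) in separate jax.profiler.trace captures with
create_perfetto_trace=True, one output directory per phase. A minimal sketch,
assuming a throwaway jitted step() and /tmp trace directories (both
illustrative, not taken from the benchmark):

# Sketch (Python/JAX) of the per-phase profiling pattern behind --trace.
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.fft.fftn(x).real.sum()

x = jnp.ones((64, 64, 64))
for tag in ("jit_trace", "first_run_trace", "second_run_trace"):
    # the first iteration also captures compilation, like the benchmark's jit_trace
    with jax.profiler.trace(f"/tmp/profiling/{tag}", create_perfetto_trace=True):
        step(x).block_until_ready()
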
diff --git a/benchmarks/bench_pm.py b/benchmarks/bench_pm.py
index 9b916f9..c80fdd5 100644
--- a/benchmarks/bench_pm.py
+++ b/benchmarks/bench_pm.py
@@ -25,7 +25,7 @@ from jax.sharding import PartitionSpec as P
 from jaxpm.kernels import interpolate_power_spectrum
 from jaxpm.painting import cic_paint_dx
 from jaxpm.pm import linear_field, lpt, make_ode_fn
-
+from jax import make_jaxpr
 
 
 def run_simulation(mesh_shape,
@@ -33,7 +33,10 @@ def run_simulation(mesh_shape,
                    halo_size,
                    solver_choice,
                    iterations,
-                   pdims=None):
+                   hlo_print,
+                   trace,
+                   pdims=None,
+                   output_path="."):
 
     @jax.jit
     def simulate(omega_c, sigma8):
@@ -60,7 +63,6 @@ def run_simulation(mesh_shape,
             solver = Tsit5()
         elif solver_choice == "lpt":
             lpt_field = cic_paint_dx(dx, halo_size=halo_size)
-            print(f"TYPE of lpt_field: {type(lpt_field)}")
             return lpt_field, {"num_steps": 0}
         else:
             raise ValueError(
@@ -92,16 +94,37 @@ def run_simulation(mesh_shape,
 
     def run():
         # Warm start
-        chrono_fun = Timer()
-        RangePush("warmup")
-        final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8 , ndarray_arg = 0)
-        RangePop()
-        sync_global_devices("warmup")
-        for i in range(iterations):
-            RangePush(f"sim iter {i}")
-            final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8 , ndarray_arg = 0)
-            RangePop()
-        return final_field, stats, chrono_fun
+        if hlo_print:
+          jaxpr = make_jaxpr(simulate)(0.32, 0.8)
+          lowered = jax.jit(simulate).lower(0.32, 0.8)
+          lower_as_text = lowered.as_text()
+          compiled = lowered.compile()
+          compiled_again = jax.jit(simulate).lower(0.32, 0.8).compile()
+          return jaxpr, compiled, compiled_again
+        elif trace:
+          jit_output = f"{output_path}/jit_trace"
+          first_run_output = f"{output_path}/first_run_trace"
+          second_run_output = f"{output_path}/second_run_trace"
+          with jax.profiler.trace(jit_output, create_perfetto_trace=True):
+            final_field, stats = simulate(0.32, 0.8)
+            final_field.block_until_ready()
+          with jax.profiler.trace(first_run_output, create_perfetto_trace=True):
+            final_field, stats = simulate(0.32, 0.8)
+            final_field.block_until_ready()
+          with jax.profiler.trace(second_run_output, create_perfetto_trace=True):
+            final_field, stats = simulate(0.32, 0.8)
+            final_field.block_until_ready()
+        else:
+          chrono_fun = Timer()
+          RangePush("warmup")
+          final_field, stats = chrono_fun.chrono_jit(simulate, 0.32, 0.8, ndarray_arg=0)
+          RangePop()
+          sync_global_devices("warmup")
+          for i in range(iterations):
+              RangePush(f"sim iter {i}")
+              final_field, stats = chrono_fun.chrono_fun(simulate, 0.32, 0.8, ndarray_arg=0)
+              RangePop()
+          return final_field, stats, chrono_fun
 
     if jax.device_count() > 1:
         devices = mesh_utils.create_device_mesh(pdims)
@@ -143,7 +166,7 @@ if __name__ == "__main__":
                         '--halo_size',
                         type=int,
                         help='Halo size',
-                        required=True)
+                        default=None)
     parser.add_argument('-s',
                         '--solver',
                         type=str,
@@ -153,12 +176,7 @@ if __name__ == "__main__":
                             "LeapfrogMidpoint", "leapfrogmidpoint", "lfm",
                             "lpt"
                         ],
-                        required=True)
-    parser.add_argument('-i',
-                        '--iterations',
-                        type=int,
-                        help='Number of iterations',
-                        default=10)
+                        default="lpt")
     parser.add_argument('-o',
                         '--output_path',
                         type=str,
@@ -173,16 +191,33 @@ if __name__ == "__main__":
                         type=int,
                         help='Number of nodes',
                         default=1)
+    parser.add_argument('-i',
+                        '--iterations',
+                        type=int,
+                        help='Number of iterations',
+                        default=10)
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument('-hlo',
+                        '--hlo_print',
+                        action='store_true',
+                        help='Print hlo generated by XLA')
+    group.add_argument('-t',
+                        '--trace',
+                        action='store_true',
+                        help='Profile using tensorboard')
     
     args = parser.parse_args()
     mesh_size = args.mesh_size
     box_size = [args.box_size] * 3
-    halo_size = args.halo_size
+    halo_size = args.mesh_size // 8 if args.halo_size is None else args.halo_size
     solver_choice = args.solver
     iterations = args.iterations
     output_path = args.output_path
     os.makedirs(output_path, exist_ok=True)
-    
+    hlo_print = args.hlo_print
+    trace = args.trace
+    nb_gpus = jax.device_count()
+
     print(f"solver choice: {solver_choice}")
     match solver_choice:
         case "Dopri5" | "dopri5"| "d5":
@@ -205,45 +240,81 @@ if __name__ == "__main__":
     if args.pdims:
         pdims = tuple(map(int, args.pdims.split("x")))
     else:
-        pdims = (1, 1)
+        pdims = (1, jax.device_count())
+    pdm_str = f"{pdims[0]}x{pdims[1]}"
 
     mesh_shape = [mesh_size] * 3
-
-    final_field , stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, pdims)
     
-    print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
+    if trace:
+      trace_folder = f"{output_path}/profiling/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
+      os.makedirs(trace_folder, exist_ok=True)
+      run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, trace_folder)
+      print(f"Profiling done! Check {trace_folder}")
+    elif hlo_print:
+      hlo_folder = f"{output_path}/hlo/jaxpm/{nb_gpus}/{mesh_shape[0]}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
+      os.makedirs(hlo_folder, exist_ok=True)
+      jaxpr , compiled , compiled2 = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, hlo_folder)
+      print(f"type of memory analysis {type(compiled.memory_analysis())}")
+      print(f"memory analysis {compiled.memory_analysis()}")
+      print(f"memory analysis again {compiled2.memory_analysis()}")
+      jax.tree.map(lambda x: print(x), compiled.memory_analysis())
+      with open(f'{hlo_folder}/hlo_jaxpm.md', 'w') as f:
+          f.write(f"# JAXPM HLO\n")
+          f.write(f"## Args: {args}\n")
+          f.write(f"## JAXPR is \n")
+          f.write(f'---\n')
+          f.write(f"{jaxpr}\n")
+          f.write(f'---\n')
+          f.write(f"Lowered as text is \n")
+          f.write(f'---\n')
+          # f.write(f"{lower_as_text}\n")
+          f.write(f'---\n')
+          f.write(f"Compiled is \n")
+          f.write(f'---\n')
+          f.write(f"{compiled.as_text()}\n")
+          f.write(f"Cost analysis is \n")
+          f.write(f'---\n')
+          f.write(f"{compiled.cost_analysis()[0]['flops']}\n")
+          f.write(f'---\n')
+          f.write(f"Memory analysis is \n")
+          f.write(f'---\n')
+          f.write(f"{compiled.memory_analysis()}\n")
+          f.write(f'---\n')
 
-    metadata = {
-      'rank': rank,
-      'function_name': f'JAXPM-{solver_choice}',
-      'precision': args.precision,
-      'x': str(mesh_size),
-      'y': str(mesh_size),
-      'z': str(stats["num_steps"]),
-      'px': str(pdims[0]),
-      'py': str(pdims[1]),
-      'backend': 'NCCL',
-      'nodes': str(args.nodes)
-  }
-    # Print the results to a CSV file
-    chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
+      print(f"Saved HLO to {hlo_folder}")
+    else:
+      final_field, stats, chrono_fun = run_simulation(mesh_shape, box_size, halo_size, solver_choice, iterations, hlo_print, trace, pdims, output_path)
+      print(f"shape of final_field {final_field.shape} and sharding spec {final_field.sharding} and local shape {final_field.addressable_data(0).shape}")
+      metadata = {
+        'rank': rank,
+        'function_name': f'JAXPM-{solver_choice}',
+        'precision': args.precision,
+        'x': str(mesh_size),
+        'y': str(mesh_size),
+        'z': str(stats["num_steps"]),
+        'px': str(pdims[0]),
+        'py': str(pdims[1]),
+        'backend': 'NCCL',
+        'nodes': str(args.nodes)
+      }
+      # Print the results to a CSV file
+      chrono_fun.print_to_csv(f'{output_path}/jaxpm_benchmark.csv', **metadata)
 
-    # Save the final field
-    nb_gpus = jax.device_count()
-    pdm_str = f"{pdims[0]}x{pdims[1]}"
-    field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
-    os.makedirs(field_folder, exist_ok=True)
-    with open(f'{field_folder}/jaxpm.log', 'w') as f:
-        f.write(f"Args: {args}\n")
-        f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
-        for i , time in enumerate(chrono_fun.times):
-          f.write(f"Time {i}: {time:.4f} ms\n")
-        f.write(f"Stats: {stats}\n")
-    if args.save_fields:
-        np.save(f'{field_folder}/final_field_0_{rank}.npy',
-                final_field.addressable_data(0))
+      # Save the final field
 
-    print(f"Finished! ")
-    print(f"Stats {stats}")
-    print(f"Saving to {output_path}/jax_pm_benchmark.csv")
-    print(f"Saving field and logs in {field_folder}")
+      field_folder = f"{output_path}/final_field/jaxpm/{nb_gpus}/{mesh_size}_{int(box_size[0])}/{pdm_str}/{solver_choice}/halo_{halo_size}"
+      os.makedirs(field_folder, exist_ok=True)
+      with open(f'{field_folder}/jaxpm.log', 'w') as f:
+          f.write(f"Args: {args}\n")
+          f.write(f"JIT time: {chrono_fun.jit_time:.4f} ms\n")
+          for i , time in enumerate(chrono_fun.times):
+            f.write(f"Time {i}: {time:.4f} ms\n")
+          f.write(f"Stats: {stats}\n")
+      if args.save_fields:
+          np.save(f'{field_folder}/final_field_0_{rank}.npy',
+                  final_field.addressable_data(0))
+
+      print(f"Finished! ")
+      print(f"Stats {stats}")
+      print(f"Saving to {output_path}/jax_pm_benchmark.csv")
+      print(f"Saving field and logs in {field_folder}")
diff --git a/benchmarks/particle_mesh_a100.slurm b/benchmarks/particle_mesh.slurm
similarity index 65%
rename from benchmarks/particle_mesh_a100.slurm
rename to benchmarks/particle_mesh.slurm
index 65930b2..ab55b05 100644
--- a/benchmarks/particle_mesh_a100.slurm
+++ b/benchmarks/particle_mesh.slurm
@@ -1,40 +1,15 @@
 #!/bin/bash
-##########################################
-## SELECT EITHER tkc@a100 OR tkc@v100 ##
-##########################################
-#SBATCH --account tkc@a100
-##########################################
-#SBATCH --job-name=1N-FFT-Mesh   # nom du job
-# Il est possible d'utiliser une autre partition que celle par default
-# en activant l'une des 5 directives suivantes :
-##########################################
-## SELECT EITHER a100  or v100-32g   ##
-##########################################
-#SBATCH -C a100
-##########################################
-#******************************************
-##########################################
-## SELECT Number of nodes and GPUs per node
-## For A100 ntasks-per-node and gres=gpu should be 8
-## For V100 ntasks-per-node and gres=gpu should be 4
-##########################################
-#SBATCH --nodes=1                    # nombre de noeud
-#SBATCH --ntasks-per-node=8          # nombre de tache MPI par noeud (= nombre de GPU par noeud)
-#SBATCH --gres=gpu:8                # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
-##########################################
-## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
-## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
-## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
-##########################################
+##############################################################################################################################
+# USAGE: sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/particle_mesh.slurm
+##############################################################################################################################
+#SBATCH --job-name=Particle-Mesh   # nom du job
 #SBATCH --cpus-per-task=8           # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
-##########################################
-# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
 #SBATCH --hint=nomultithread         # hyperthreading desactive
 #SBATCH --time=04:00:00              # temps d'execution maximum demande (HH:MM:SS)
 #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
-#SBATCH --error=%x_%N_a100.out  # nom du fichier d'erreur (ici commun avec la sortie)
+#SBATCH --error=%x_%N_a100.err  # nom du fichier d'erreur
+#SBATCH --exclusive                  # ressources dediees
 ##SBATCH --qos=qos_gpu-dev
-## SBATCH --exclusive                  # ressources dediees
 # Nettoyage des modules charges en interactif et herites par defaut
 num_nodes=$SLURM_JOB_NUM_NODES
 num_gpu_per_node=$SLURM_NTASKS_PER_NODE
@@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node))
 
 module purge
 
-echo "Job constraint: $SLURM_JOB_CONSTRAINT"
 echo "Job partition: $SLURM_JOB_PARTITION"
 # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
 # pour avoir acces aux modules compatibles avec cette partition
 
-if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
+if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
     module load cpuarch/amd
     source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
     gpu_name=a100
@@ -66,8 +40,10 @@ module load nvidia-nsight-systems/2024.1.1.59
 echo "The number of nodes allocated for this job is: $num_nodes"
 echo "The number of GPUs allocated for this job is: $nb_gpus"
 
+export EQX_ON_ERROR=nan
 export ENABLE_PERFO_STEP=NVTX
 export MPI4JAX_USE_CUDA_MPI=1
+
 function profile_python() {
     if [ $# -lt 1 ]; then
         echo "Usage: profile_python <python_script> [arguments for the script]"
@@ -75,14 +51,14 @@ function profile_python() {
     fi
 
     local script_name=$(basename "$1" .py)
-    local output_dir="prof_traces/$script_name"
+    local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
     local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
 
     if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
         local args=$(echo "${@:2}" | tr ' ' '_')
         # Remove characters '/' and '-' from folder name
         args=$(echo "$args" | tr -d '/-')
-        output_dir="prof_traces/$script_name/$args"
+        output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
         report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
     fi
 
@@ -99,13 +75,13 @@ function run_python() {
     fi
 
     local script_name=$(basename "$1" .py)
-    local output_dir="traces/$script_name"
+    local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
 
     if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
         local args=$(echo "${@:2}" | tr ' ' '_')
         # Remove characters '/' and '-' from folder name
         args=$(echo "$args" | tr -d '/-')
-        output_dir="traces/$script_name/$args"
+        output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
     fi
 
     mkdir -p "$output_dir"
@@ -142,12 +118,11 @@ pdims_table[8]="2x4 1x8 8x1 4x2"
 pdims_table[16]="4x4 1x16 16x1"
 pdims_table[32]="4x8 8x4 1x32 32x1"
 pdims_table[64]="8x8 16x4 1x64 64x1"
-pdims_table[128]="8x16 16x8 4x32 32x4 1x128 128x1 2x64 64x2"
-pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
-
+pdims_table[128]="8x16 16x8 1x128 128x1"
+pdims_table[256]="16x16 1x256 256x1"
 
 # mpch=(128 256 512 1024 2048 4096)
-grid=(256 512 1024 2048 4096)
+grid=(256 512 1024 2048 4096 8192)
 precisions=(float32 float64)
 pdim="${pdims_table[$nb_gpus]}"
 solvers=(lpt lfm)
@@ -164,8 +139,9 @@ fi
 
 # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
 out_dir="pm_prof/$gpu_name/$nb_gpus"
-
-echo "Output dir is  : $out_dir"
+trace_dir="traces/$gpu_name/$nb_gpus/bench_pm"
+echo "Output dir is  : $out_dir" 
+echo "Trace dir is  : $trace_dir"
 
 for pr in "${precisions[@]}"; do
     for g in "${grid[@]}"; do
@@ -175,9 +151,13 @@ for pr in "${precisions[@]}"; do
                 slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
             done
         done
+        # delete crash core dump files
+        rm -f core.python.*
     done
 done
 
-
-
-
+# # zip the output files and traces
+# tar -czvf $out_dir.tar.gz $out_dir
+# tar -czvf $trace_dir.tar.gz $trace_dir
+# # remove the output files and traces
+# rm -rf $out_dir $trace_dir
diff --git a/benchmarks/particle_mesh_v100.slurm b/benchmarks/particle_mesh_v100.slurm
deleted file mode 100644
index 9eeb610..0000000
--- a/benchmarks/particle_mesh_v100.slurm
+++ /dev/null
@@ -1,184 +0,0 @@
-#!/bin/bash
-##########################################
-## SELECT EITHER tkc@a100 OR tkc@v100 ##
-##########################################
-#SBATCH --account tkc@v100
-##########################################
-#SBATCH --job-name=V100Particle-Mesh   # nom du job
-# Il est possible d'utiliser une autre partition que celle par default
-# en activant l'une des 5 directives suivantes :
-##########################################
-## SELECT EITHER a100  or v100-32g   ##
-##########################################
-#SBATCH -C v100-32g
-##########################################
-#******************************************
-##########################################
-## SELECT Number of nodes and GPUs per node
-## For A100 ntasks-per-node and gres=gpu should be 8
-## For V100 ntasks-per-node and gres=gpu should be 4
-##########################################
-#SBATCH --nodes=1                    # nombre de noeud
-#SBATCH --ntasks-per-node=4          # nombre de tache MPI par noeud (= nombre de GPU par noeud)
-#SBATCH --gres=gpu:4                # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
-##########################################
-## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
-## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
-## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
-##########################################
-#SBATCH --cpus-per-task=8           # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
-##########################################
-# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
-#SBATCH --hint=nomultithread         # hyperthreading desactive
-#SBATCH --time=02:00:00              # temps d'execution maximum demande (HH:MM:SS)
-#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
-#SBATCH --error=%x_%N_a100.out  # nom du fichier d'erreur (ici commun avec la sortie)
-#SBATCH --qos=qos_gpu-dev
-#SBATCH --exclusive                  # ressources dediees
-# Nettoyage des modules charges en interactif et herites par defaut
-num_nodes=$SLURM_JOB_NUM_NODES
-num_gpu_per_node=$SLURM_NTASKS_PER_NODE
-OUTPUT_FOLDER_ARGS=1
-# Calculate the number of GPUs
-nb_gpus=$(( num_nodes * num_gpu_per_node))
-
-module purge
-
-echo "Job constraint: $SLURM_JOB_CONSTRAINT"
-echo "Job partition: $SLURM_JOB_PARTITION"
-# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
-# pour avoir acces aux modules compatibles avec cette partition
-
-if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
-    module load cpuarch/amd
-    source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
-    gpu_name=a100
-else
-    source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
-    gpu_name=v100
-fi
-
-# Chargement des modules
-module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda  openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
-module load nvidia-nsight-systems/2024.1.1.59
-
-
-echo "The number of nodes allocated for this job is: $num_nodes"
-echo "The number of GPUs allocated for this job is: $nb_gpus"
-
-export EQX_ON_ERROR=nan
-export CUDA_ALLOC=1
-export ENABLE_PERFO_STEP=NVTX
-export MPI4JAX_USE_CUDA_MPI=1
-function profile_python() {
-    if [ $# -lt 1 ]; then
-        echo "Usage: profile_python <python_script> [arguments for the script]"
-        return 1
-    fi
-
-    local script_name=$(basename "$1" .py)
-    local output_dir="prof_traces/$script_name"
-    local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
-
-    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
-        local args=$(echo "${@:2}" | tr ' ' '_')
-        # Remove characters '/' and '-' from folder name
-        args=$(echo "$args" | tr -d '/-')
-        output_dir="prof_traces/$script_name/$args"
-        report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
-    fi
-
-    mkdir -p "$output_dir"
-    mkdir -p "$report_dir"
-
-    srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
-}
-
-function run_python() {
-    if [ $# -lt 1 ]; then
-        echo "Usage: run_python <python_script> [arguments for the script]"
-        return 1
-    fi
-
-    local script_name=$(basename "$1" .py)
-    local output_dir="traces/$script_name"
-
-    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
-        local args=$(echo "${@:2}" | tr ' ' '_')
-        # Remove characters '/' and '-' from folder name
-        args=$(echo "$args" | tr -d '/-')
-        output_dir="traces/$script_name/$args"
-    fi
-
-    mkdir -p "$output_dir"
-
-    srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
-}
-
-
-# run or profile
-
-function slaunch() {
-    run_python "$@"
-}
-
-function plaunch() {
-    profile_python "$@"
-}
-
-# Echo des commandes lancees
-set -x
-
-# Pour ne pas utiliser le /tmp
-export TMPDIR=$JOBSCRATCH
-# Pour contourner un bogue dans les versions actuelles de Nsight Systems
-# il est également nécessaire de créer un lien symbolique permettant de
-# faire pointer le répertoire /tmp/nvidia vers TMPDIR
-ln -s $JOBSCRATCH /tmp/nvidia
-
-declare -A pdims_table
-# Define the table
-pdims_table[1]="1x1"
-pdims_table[4]="2x2 1x4 4x1"
-pdims_table[8]="2x4 1x8 8x1 4x2"
-pdims_table[16]="4x4 1x16 16x1"
-pdims_table[32]="4x8 8x4 1x32 32x1"
-pdims_table[64]="8x8 16x4 1x64 64x1"
-pdims_table[128]="8x16 16x8 4x32 1x128 128x1"
-pdims_table[160]="8x20 20x8 16x10 10x16 5x32 32x5 1x160 160x1 2x80 80x2 4x40 40x4"
-
-
-# mpch=(128 256 512 1024 2048 4096)
-grid=(256 512 1024 2048 4096)
-precisions=(float32 float64)
-pdim="${pdims_table[$nb_gpus]}"
-solvers=(lpt lfm)
-echo "pdims: $pdim"
-
-# Check if pdims is not empty
-if [ -z "$pdim" ]; then
-    echo "pdims is empty"
-    echo "Number of gpus has to be 8, 16, 32, 64, 128 or 160"
-    echo "Number of nodes selected: $num_nodes"
-    echo "Number of gpus per node: $num_gpu_per_node"
-    exit 1
-fi
-
-# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
-out_dir="pm_prof/$gpu_name/$nb_gpus"
-
-echo "Output dir is  : $out_dir"
-
-for pr in "${precisions[@]}"; do
-    for g in "${grid[@]}"; do
-        for solver in "${solvers[@]}"; do
-            for p in $pdim; do
-                halo_size=$((g / 4))
-                slaunch bench_pm.py -m $g -b $g -p $p -hs $halo_size -pr $pr -s $solver -i 4 -o $out_dir -f -n $num_nodes
-            done
-        done
-    done
-done
-
-
-
diff --git a/benchmarks/pmwd_a100.slurm b/benchmarks/pmwd_a100.slurm
deleted file mode 100644
index 99c64c7..0000000
--- a/benchmarks/pmwd_a100.slurm
+++ /dev/null
@@ -1,165 +0,0 @@
-#!/bin/bash
-##########################################
-## SELECT EITHER tkc@a100 OR tkc@v100 ##
-##########################################
-#SBATCH --account tkc@a100
-##########################################
-#SBATCH --job-name=1N-FFT-Mesh   # nom du job
-# Il est possible d'utiliser une autre partition que celle par default
-# en activant l'une des 5 directives suivantes :
-##########################################
-## SELECT EITHER a100  or v100-32g   ##
-##########################################
-#SBATCH -C a100
-##########################################
-#******************************************
-##########################################
-## SELECT Number of nodes and GPUs per node
-## For A100 ntasks-per-node and gres=gpu should be 8
-## For V100 ntasks-per-node and gres=gpu should be 4
-##########################################
-#SBATCH --nodes=1                    # nombre de noeud
-#SBATCH --ntasks-per-node=1          # nombre de tache MPI par noeud (= nombre de GPU par noeud)
-#SBATCH --gres=gpu:1                # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
-##########################################
-## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
-## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
-## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
-##########################################
-#SBATCH --cpus-per-task=8           # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
-##########################################
-# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
-#SBATCH --hint=nomultithread         # hyperthreading desactive
-#SBATCH --time=04:00:00              # temps d'execution maximum demande (HH:MM:SS)
-#SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
-#SBATCH --error=%x_%N_a100.out  # nom du fichier d'erreur (ici commun avec la sortie)
-##SBATCH --qos=qos_gpu-dev
-## SBATCH --exclusive                  # ressources dediees
-# Nettoyage des modules charges en interactif et herites par defaut
-num_nodes=$SLURM_JOB_NUM_NODES
-num_gpu_per_node=$SLURM_NTASKS_PER_NODE
-OUTPUT_FOLDER_ARGS=1
-# Calculate the number of GPUs
-nb_gpus=$(( num_nodes * num_gpu_per_node))
-
-module purge
-
-echo "Job constraint: $SLURM_JOB_CONSTRAINT"
-echo "Job partition: $SLURM_JOB_PARTITION"
-# Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
-# pour avoir acces aux modules compatibles avec cette partition
-
-if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
-    module load cpuarch/amd
-    source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
-    gpu_name=a100
-else
-    source /gpfsdswork/projects/rech/tkc/commun/venv/v100/bin/activate
-    gpu_name=v100
-fi
-
-# Chargement des modules
-module load nvidia-compilers/23.9 cuda/12.2.0 cudnn/8.9.7.29-cuda  openmpi/4.1.5-cuda nccl/2.18.5-1-cuda cmake
-module load nvidia-nsight-systems/2024.1.1.59
-
-
-echo "The number of nodes allocated for this job is: $num_nodes"
-echo "The number of GPUs allocated for this job is: $nb_gpus"
-
-export EQX_ON_ERROR=nan
-export CUDA_ALLOC=1
-export ENABLE_PERFO_STEP=NVTX
-export MPI4JAX_USE_CUDA_MPI=1
-function profile_python() {
-    if [ $# -lt 1 ]; then
-        echo "Usage: profile_python <python_script> [arguments for the script]"
-        return 1
-    fi
-
-    local script_name=$(basename "$1" .py)
-    local output_dir="prof_traces/$script_name"
-    local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
-
-    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
-        local args=$(echo "${@:2}" | tr ' ' '_')
-        # Remove characters '/' and '-' from folder name
-        args=$(echo "$args" | tr -d '/-')
-        output_dir="prof_traces/$script_name/$args"
-        report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
-    fi
-
-    mkdir -p "$output_dir"
-    mkdir -p "$report_dir"
-
-    srun timeout 10m nsys profile -t cuda,nvtx,osrt,mpi -o "$report_dir/report_rank%q{SLURM_PROCID}" python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
-}
-
-function run_python() {
-    if [ $# -lt 1 ]; then
-        echo "Usage: run_python <python_script> [arguments for the script]"
-        return 1
-    fi
-
-    local script_name=$(basename "$1" .py)
-    local output_dir="traces/$script_name"
-
-    if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
-        local args=$(echo "${@:2}" | tr ' ' '_')
-        # Remove characters '/' and '-' from folder name
-        args=$(echo "$args" | tr -d '/-')
-        output_dir="traces/$script_name/$args"
-    fi
-
-    mkdir -p "$output_dir"
-
-    srun timeout 10m python "$@" > "$output_dir/$script_name.out" 2> "$output_dir/$script_name.err" || true
-}
-
-
-# run or profile
-
-function slaunch() {
-    run_python "$@"
-}
-
-function plaunch() {
-    profile_python "$@"
-}
-
-# Echo des commandes lancees
-set -x
-
-# Pour ne pas utiliser le /tmp
-export TMPDIR=$JOBSCRATCH
-# Pour contourner un bogue dans les versions actuelles de Nsight Systems
-# il est également nécessaire de créer un lien symbolique permettant de
-# faire pointer le répertoire /tmp/nvidia vers TMPDIR
-ln -s $JOBSCRATCH /tmp/nvidia
-
-# mpch=(128 256 512 1024 2048 4096)
-grid=(256 512 1024 2048 4096)
-precisions=(float32 float64)
-solvers=(lpt lfm)
-
-# GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
-
-if [ $num_gpu_per_node -eq 8 ]; then
-    gpu_name="a100"
-else
-    gpu_name="v100"
-fi
-
-out_dir="pm_prof/$gpu_name/$nb_gpus"
-
-echo "Output dir is  : $out_dir"
-
-for pr in "${precisions[@]}"; do
-    for g in "${grid[@]}"; do
-        for solver in "${solvers[@]}"; do
-           launch bench_pmwd.py -m $g -b $g -p $p -pr $pr -s $solver -i 4 -o $out_dir -f  
-        done
-    done
-done
-
-
-
diff --git a/benchmarks/pmwd_v100.slurm b/benchmarks/pmwd_pm.slurm
similarity index 64%
rename from benchmarks/pmwd_v100.slurm
rename to benchmarks/pmwd_pm.slurm
index 4c58db5..3e2cee8 100644
--- a/benchmarks/pmwd_v100.slurm
+++ b/benchmarks/pmwd_pm.slurm
@@ -1,40 +1,15 @@
 #!/bin/bash
-##########################################
-## SELECT EITHER tkc@a100 OR tkc@v100 ##
-##########################################
-#SBATCH --account tkc@v100
-##########################################
-#SBATCH --job-name=16N-V100Particle-Mesh   # nom du job
-# Il est possible d'utiliser une autre partition que celle par default
-# en activant l'une des 5 directives suivantes :
-##########################################
-## SELECT EITHER a100  or v100-32g   ##
-##########################################
-#SBATCH -C v100-32g
-##########################################
-#******************************************
-##########################################
-## SELECT Number of nodes and GPUs per node
-## For A100 ntasks-per-node and gres=gpu should be 8
-## For V100 ntasks-per-node and gres=gpu should be 4
-##########################################
-#SBATCH --nodes=1                    # nombre de noeud
-#SBATCH --ntasks-per-node=1          # nombre de tache MPI par noeud (= nombre de GPU par noeud)
-#SBATCH --gres=gpu:1                # nombre de GPU par nœud (max 8 avec gpu_p2, gpu_p5)
-##########################################
-## Le nombre de CPU par tache doit etre adapte en fonction de la partition utilisee. Sachant
-## qu'ici on ne reserve qu'un seul GPU par tache (soit 1/4 ou 1/8 des GPU du noeud suivant
-## la partition), l'ideal est de reserver 1/4 ou 1/8 des CPU du noeud pour chaque tache:
-##########################################
+##############################################################################################################################
+# USAGE: sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100 benchmarks/pmwd_pm.slurm
+##############################################################################################################################
+#SBATCH --job-name=PMWD-Particle-Mesh   # nom du job
 #SBATCH --cpus-per-task=8           # nombre de CPU par tache pour gpu_p5 (1/8 du noeud 8-GPU)
-##########################################
-# /!\ Attention, "multithread" fait reference a l'hyperthreading dans la terminologie Slurm
 #SBATCH --hint=nomultithread         # hyperthreading desactive
-#SBATCH --time=02:00:00              # temps d'execution maximum demande (HH:MM:SS)
+#SBATCH --time=04:00:00              # temps d'execution maximum demande (HH:MM:SS)
 #SBATCH --output=%x_%N_a100.out # nom du fichier de sortie
 #SBATCH --error=%x_%N_a100.out  # nom du fichier d'erreur (ici commun avec la sortie)
-#SBATCH --qos=qos_gpu-dev
 #SBATCH --exclusive                  # ressources dediees
+##SBATCH --qos=qos_gpu-dev
 # Nettoyage des modules charges en interactif et herites par defaut
 num_nodes=$SLURM_JOB_NUM_NODES
 num_gpu_per_node=$SLURM_NTASKS_PER_NODE
@@ -44,12 +19,11 @@ nb_gpus=$(( num_nodes * num_gpu_per_node))
 
 module purge
 
-echo "Job constraint: $SLURM_JOB_CONSTRAINT"
 echo "Job partition: $SLURM_JOB_PARTITION"
 # Decommenter la commande module suivante si vous utilisez la partition "gpu_p5"
 # pour avoir acces aux modules compatibles avec cette partition
 
-if [ $SLURM_JOB_PARTITION -eq gpu_p5 ]; then
+if [[ "$SLURM_JOB_PARTITION" == "gpu_p5" ]]; then
     module load cpuarch/amd
     source /gpfsdswork/projects/rech/tkc/commun/venv/a100/bin/activate
     gpu_name=a100
@@ -67,9 +41,9 @@ echo "The number of nodes allocated for this job is: $num_nodes"
 echo "The number of GPUs allocated for this job is: $nb_gpus"
 
 export EQX_ON_ERROR=nan
-export CUDA_ALLOC=1
 export ENABLE_PERFO_STEP=NVTX
 export MPI4JAX_USE_CUDA_MPI=1
+
 function profile_python() {
     if [ $# -lt 1 ]; then
         echo "Usage: profile_python <python_script> [arguments for the script]"
@@ -77,14 +51,14 @@ function profile_python() {
     fi
 
     local script_name=$(basename "$1" .py)
-    local output_dir="prof_traces/$script_name"
+    local output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name"
     local report_dir="out_prof/$gpu_name/$nb_gpus/$script_name"
 
     if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
         local args=$(echo "${@:2}" | tr ' ' '_')
         # Remove characters '/' and '-' from folder name
         args=$(echo "$args" | tr -d '/-')
-        output_dir="prof_traces/$script_name/$args"
+        output_dir="prof_traces/$gpu_name/$nb_gpus/$script_name/$args"
         report_dir="out_prof/$gpu_name/$nb_gpus/$script_name/$args"
     fi
 
@@ -101,13 +75,13 @@ function run_python() {
     fi
 
     local script_name=$(basename "$1" .py)
-    local output_dir="traces/$script_name"
+    local output_dir="traces/$gpu_name/$nb_gpus/$script_name"
 
     if [ $OUTPUT_FOLDER_ARGS -eq 1 ]; then
         local args=$(echo "${@:2}" | tr ' ' '_')
         # Remove characters '/' and '-' from folder name
         args=$(echo "$args" | tr -d '/-')
-        output_dir="traces/$script_name/$args"
+        output_dir="traces/$gpu_name/$nb_gpus/$script_name/$args"
     fi
 
     mkdir -p "$output_dir"
@@ -116,6 +90,7 @@ function run_python() {
 }
 
 
+
 # run or profile
 
 function slaunch() {
@@ -136,13 +111,9 @@ export TMPDIR=$JOBSCRATCH
 # faire pointer le répertoire /tmp/nvidia vers TMPDIR
 ln -s $JOBSCRATCH /tmp/nvidia
 
-
-
-
 # mpch=(128 256 512 1024 2048 4096)
-grid=(256 512 1024 2048 4096)
+grid=(256 512 1024 2048 4096 8192)
 precisions=(float32 float64)
-pdim="${pdims_table[$nb_gpus]}"
 solvers=(lpt lfm)
 
 # GPU name is a100 if num_gpu_per_node is 8, otherwise it is v100
@@ -154,17 +125,23 @@ else
 fi
 
 out_dir="pm_prof/$gpu_name/$nb_gpus"
-
-echo "Output dir is  : $out_dir"
-
+trace_dir="traces/$gpu_name/$nb_gpus/bench_pmwd"
+echo "Output dir is  : $out_dir" 
+echo "Trace dir is  : $trace_dir"
 
 for pr in "${precisions[@]}"; do
     for g in "${grid[@]}"; do
         for solver in "${solvers[@]}"; do
            slaunch bench_pmwd.py -m $g -b $g -pr $pr -s $solver -i 4 -o $out_dir -f  
         done
+        # delete crash core dump files
+        rm -f core.python.*
     done
 done
 
-
-
+# # zip the output files and traces
+# tar -czvf $out_dir.tar.gz $out_dir
+# tar -czvf $trace_dir.tar.gz $trace_dir
+# # remove the output files and traces
+# rm -rf $out_dir $trace_dir
+#
diff --git a/benchmarks/run_all_jobs.sh b/benchmarks/run_all_jobs.sh
index cfee491..2e6cb3f 100755
--- a/benchmarks/run_all_jobs.sh
+++ b/benchmarks/run_all_jobs.sh
@@ -1,19 +1,21 @@
 #!/bin/bash
 # Run all slurms jobs
-nodes_v100=(1 2 4 8 16)
-nodes_a100=(1 2 4 8 16)
+nodes_v100=(1 2 4 8 16 32)
+nodes_a100=(1 2 4 8 16 32)
 
 
 for n in ${nodes_v100[@]}; do
-    sbatch --nodes=$n  --job-name=v100_$n-JAXPM particle_mesh_v100.slurm
+    sbatch --account=tkc@v100 --nodes=$n --gres=gpu:4 --tasks-per-node=4 -C v100-32g --job-name=JAXPM-$n-N-v100 particle_mesh.slurm
 done
 
 for n in ${nodes_a100[@]}; do
-    sbatch --nodes=$n --job-name=a100_$n-JAXPM particle_mesh_a100.slurm
+    sbatch --account=tkc@a100 --nodes=$n --gres=gpu:8 --tasks-per-node=8 -C a100     --job-name=JAXPM-$n-N-a100 particle_mesh.slurm
 done
 
 # single GPUs
-sbatch --job-name=JAXPM-1GPU-V100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_v100.slurm
-sbatch --job-name=JAXPM-1GPU-A100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 particle_mesh_a100.slurm
-sbatch --job-name=PMWD-v100 pmwd_v100.slurm
-sbatch --job-name=PMWD-a100 pmwd_a100.slurm
+sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100     --job-name=JAXPM-1GPU-A100 particle_mesh.slurm
+sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=JAXPM-1GPU-V100 particle_mesh.slurm
+sbatch --account=tkc@a100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C a100     --job-name=PMWD-1GPU-a100 pmwd_pm.slurm
+sbatch --account=tkc@v100 --nodes=1 --gres=gpu:1 --tasks-per-node=1 -C v100-32g --job-name=PMWD-1GPU-v100 pmwd_pm.slurm
+
+
diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py
index 9fb0e15..c16bc33 100644
--- a/jaxpm/distributed.py
+++ b/jaxpm/distributed.py
@@ -88,8 +88,8 @@ def get_halo_size(halo_size):
 
 def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
     mesh = mesh_lib.thread_resources.env.physical_mesh
-    if distributed and not (mesh.empty) and (halo_extents[0] > 0
-                                             or halo_extents[1] > 0):
+    if (distributed and not mesh.empty
+            and halo_extents > 0):
         return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
     else:
         return x
diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index 7160913..7d9e9fa 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -50,15 +50,15 @@ def cic_paint_impl(mesh, displacement, weight=None):
 @partial(jax.jit, static_argnums=(2, ))
 def cic_paint(mesh, positions, halo_size=0, weight=None):
 
-    halo_size, halo_extents = get_halo_size(halo_size)
-    mesh = slice_pad(mesh, halo_size)
+    halo_padding, halo_extents = get_halo_size(halo_size)
+    mesh = slice_pad(mesh, halo_padding)
     mesh = autoshmap(cic_paint_impl,
                      in_specs=(P('x', 'y'), P('x', 'y'), P()),
                      out_specs=P('x', 'y'))(mesh, positions, weight)
     mesh = halo_exchange(mesh,
-                         halo_extents=halo_extents,
-                         halo_periods=(True, True, True))
-    mesh = slice_unpad(mesh, halo_size)
+                         halo_extents=halo_size // 2,
+                         halo_periods=True)
+    mesh = slice_unpad(mesh, halo_padding)
     return mesh
 
 
@@ -95,11 +95,11 @@ def cic_read_impl(mesh, displacement):
 @partial(jax.jit, static_argnums=(2, ))
 def cic_read(mesh, displacement, halo_size=0):
 
-    halo_size, halo_extents = get_halo_size(halo_size)
-    mesh = slice_pad(mesh, halo_size)
+    halo_padding, halo_extents = get_halo_size(halo_size)
+    mesh = slice_pad(mesh, halo_padding)
     mesh = halo_exchange(mesh,
-                         halo_extents=halo_extents,
-                         halo_periods=(True, True, True))
+                         halo_extents=halo_size // 2,
+                         halo_periods=True)
     displacement = autoshmap(cic_read_impl,
                              in_specs=(P('x', 'y'), P('x', 'y')),
                              out_specs=P('x', 'y'))(mesh, displacement)
@@ -160,16 +160,16 @@ def cic_paint_dx_impl(displacements, halo_size):
 @partial(jax.jit, static_argnums=(1, ))
 def cic_paint_dx(displacements, halo_size=0):
     
-    halo_size, halo_extents = get_halo_size(halo_size)
+    halo_padding, halo_extents = get_halo_size(halo_size)
                 
-    mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size),
+    mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_padding),
                      in_specs=(P('x', 'y')),
                      out_specs=P('x', 'y'))(displacements)
                           
     mesh = halo_exchange(mesh,
-                         halo_extents=halo_extents,
-                         halo_periods=(True, True, True))
-    mesh = slice_unpad(mesh, halo_size)
+                         halo_extents=halo_size // 2,
+                         halo_periods=True)
+    mesh = slice_unpad(mesh, halo_padding)
     return mesh
 
 
@@ -194,12 +194,12 @@ def cic_read_dx_impl(mesh ,  halo_size):
 @partial(jax.jit, static_argnums=(1, ))
 def cic_read_dx(mesh, halo_size=0):
     # return mesh
-    halo_size, halo_extents = get_halo_size(halo_size)
-    mesh = slice_pad(mesh, halo_size)
+    halo_padding, halo_extents = get_halo_size(halo_size)
+    mesh = slice_pad(mesh, halo_padding)
     mesh = halo_exchange(mesh,
-                         halo_extents=halo_extents,
-                         halo_periods=(True, True, True))
-    displacements = autoshmap(partial(cic_read_dx_impl ,  halo_size=halo_size),
+                         halo_extents=halo_size // 2,
+                         halo_periods=True)
+    displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_padding),
                               in_specs=(P('x', 'y')),
                               out_specs=P('x', 'y'))(mesh)
 
diff --git a/scripts/particle_mesh.slurm b/scripts/particle_mesh.slurm
index 51332a5..e8924d1 100644
--- a/scripts/particle_mesh.slurm
+++ b/scripts/particle_mesh.slurm
@@ -62,8 +62,8 @@ module load nvidia-nsight-systems/2024.1.1.59
 echo "The number of nodes allocated for this job is: $num_nodes"
 echo "The number of GPUs allocated for this job is: $nb_gpus"
 
-export EQX_ON_ERROR=nan
-export CUDA_ALLOC=1
+export ENABLE_PERFO_STEP=NVTX
+export MPI4JAX_USE_CUDA_MPI=1
 
 function profile_python() {
     if [ $# -lt 1 ]; then
@@ -122,6 +122,7 @@ set -x
 
 declare -A pdims_table
 # Define the table
+pdims_table[1]="1x1"
 pdims_table[4]="2x2 1x4"
 pdims_table[8]="2x4 1x8"
 pdims_table[16]="2x8 1x16"