Update for jaxDecomp pure JAX

This commit is contained in:
Wassim KABALAN 2024-08-07 23:52:13 +02:00
parent 831291c1f9
commit 2ea05a1cd6
9 changed files with 214 additions and 532 deletions

View file

@ -88,8 +88,8 @@ def get_halo_size(halo_size):
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (halo_extents[0] > 0
or halo_extents[1] > 0):
if distributed and not (mesh.empty) and (halo_extents > 0
or halo_extents > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x