merge with JZ

This commit is contained in:
Wassim KABALAN 2024-10-18 14:59:40 -04:00
commit ab86699c88
7 changed files with 226 additions and 131 deletions

View file

@ -91,8 +91,8 @@ def get_halo_size(halo_size):
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (halo_extents[0] > 0
or halo_extents[1] > 0):
if distributed and not (mesh.empty) and (halo_extents > 0
or halo_extents > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x