revert single halo extent change

This commit is contained in:
Wassim KABALAN 2024-10-20 09:58:24 -04:00
parent ab86699c88
commit afecb13cde
2 changed files with 22 additions and 34 deletions

View file

@ -86,13 +86,13 @@ def get_halo_size(halo_size):
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0))
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (halo_extents > 0
or halo_extents > 0):
if distributed and not (mesh.empty) and (halo_extents[0] > 0
or halo_extents[1] > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x