mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
revert single halo extent change
This commit is contained in:
parent
ab86699c88
commit
afecb13cde
2 changed files with 22 additions and 34 deletions
|
@ -86,13 +86,13 @@ def get_halo_size(halo_size):
|
|||
|
||||
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
||||
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0))
|
||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||
|
||||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (halo_extents > 0
|
||||
or halo_extents > 0):
|
||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||
or halo_extents[1] > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue