mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Update for jaxDecomp pure JAX
This commit is contained in:
parent
831291c1f9
commit
2ea05a1cd6
9 changed files with 214 additions and 532 deletions
|
@ -88,8 +88,8 @@ def get_halo_size(halo_size):
|
|||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||
or halo_extents[1] > 0):
|
||||
if distributed and not (mesh.empty) and (halo_extents > 0
|
||||
or halo_extents > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue