add a strict dependency on jaxdecomp

This commit is contained in:
Wassim KABALAN 2024-10-22 11:02:55 -04:00
parent a160a3faa9
commit 38714cf65d

View file

@ -3,13 +3,6 @@ from typing import Any, Callable, Hashable
Specs = Any
AxisName = Hashable
try:
import jaxdecomp
distributed = True
except ImportError:
print("jaxdecomp not installed. Distributed functions will not work.")
distributed = False
from functools import partial
import jax
@ -48,17 +41,11 @@ def autoshmap(
def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
else:
return jnp.fft.fftn(x.astype(jnp.complex64))
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
def ifft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pifft3d(x).real
else:
return jnp.fft.ifftn(x).real
return jaxdecomp.pifft3d(x).real
def get_halo_size(halo_size):
@ -77,10 +64,8 @@ def get_halo_size(halo_size):
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
mesh = mesh_lib.thread_resources.env.physical_mesh
if distributed and not (mesh.empty) and (halo_extents[0] > 0
or halo_extents[1] > 0):
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
if (halo_extents[0] > 0 or halo_extents[1] > 0):
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
else:
return x