diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index f4fad8a..35d5e8b 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -3,13 +3,6 @@ from typing import Any, Callable, Hashable Specs = Any AxisName = Hashable -try: - import jaxdecomp - distributed = True -except ImportError: - print("jaxdecomp not installed. Distributed functions will not work.") - distributed = False - from functools import partial import jax @@ -48,17 +41,11 @@ def autoshmap( def fft3d(x): - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - return jaxdecomp.pfft3d(x.astype(jnp.complex64)) - else: - return jnp.fft.fftn(x.astype(jnp.complex64)) + return jaxdecomp.pfft3d(x.astype(jnp.complex64)) def ifft3d(x): - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - return jaxdecomp.pifft3d(x).real - else: - return jnp.fft.ifftn(x).real + return jaxdecomp.pifft3d(x).real def get_halo_size(halo_size): @@ -77,10 +64,8 @@ def get_halo_size(halo_size): return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext)) -def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): - mesh = mesh_lib.thread_resources.env.physical_mesh - if distributed and not (mesh.empty) and (halo_extents[0] > 0 - or halo_extents[1] > 0): +def halo_exchange(x, halo_extents, halo_periods=(True, True)): + if (halo_extents[0] > 0 or halo_extents[1] > 0): return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) else: return x