mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
add a strict dependency on jaxdecomp
This commit is contained in:
parent
a160a3faa9
commit
38714cf65d
1 changed files with 4 additions and 19 deletions
|
@ -3,13 +3,6 @@ from typing import Any, Callable, Hashable
|
|||
Specs = Any
|
||||
AxisName = Hashable
|
||||
|
||||
try:
|
||||
import jaxdecomp
|
||||
distributed = True
|
||||
except ImportError:
|
||||
print("jaxdecomp not installed. Distributed functions will not work.")
|
||||
distributed = False
|
||||
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
|
@ -48,17 +41,11 @@ def autoshmap(
|
|||
|
||||
|
||||
def fft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||
else:
|
||||
return jnp.fft.fftn(x.astype(jnp.complex64))
|
||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||
|
||||
|
||||
def ifft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
else:
|
||||
return jnp.fft.ifftn(x).real
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
|
||||
|
||||
def get_halo_size(halo_size):
|
||||
|
@ -77,10 +64,8 @@ def get_halo_size(halo_size):
|
|||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||
|
||||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||
or halo_extents[1] > 0):
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
|
||||
if (halo_extents[0] > 0 or halo_extents[1] > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
|
Loading…
Add table
Reference in a new issue