mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-19 01:20:55 +00:00
add a strict dependency on jaxdecomp
This commit is contained in:
parent
a160a3faa9
commit
38714cf65d
1 changed files with 4 additions and 19 deletions
|
@ -3,13 +3,6 @@ from typing import Any, Callable, Hashable
|
||||||
Specs = Any
|
Specs = Any
|
||||||
AxisName = Hashable
|
AxisName = Hashable
|
||||||
|
|
||||||
try:
|
|
||||||
import jaxdecomp
|
|
||||||
distributed = True
|
|
||||||
except ImportError:
|
|
||||||
print("jaxdecomp not installed. Distributed functions will not work.")
|
|
||||||
distributed = False
|
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
@ -48,17 +41,11 @@ def autoshmap(
|
||||||
|
|
||||||
|
|
||||||
def fft3d(x):
|
def fft3d(x):
|
||||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
|
||||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||||
else:
|
|
||||||
return jnp.fft.fftn(x.astype(jnp.complex64))
|
|
||||||
|
|
||||||
|
|
||||||
def ifft3d(x):
|
def ifft3d(x):
|
||||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
|
||||||
return jaxdecomp.pifft3d(x).real
|
return jaxdecomp.pifft3d(x).real
|
||||||
else:
|
|
||||||
return jnp.fft.ifftn(x).real
|
|
||||||
|
|
||||||
|
|
||||||
def get_halo_size(halo_size):
|
def get_halo_size(halo_size):
|
||||||
|
@ -77,10 +64,8 @@ def get_halo_size(halo_size):
|
||||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
|
||||||
|
|
||||||
|
|
||||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
if (halo_extents[0] > 0 or halo_extents[1] > 0):
|
||||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
|
||||||
or halo_extents[1] > 0):
|
|
||||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||||
else:
|
else:
|
||||||
return x
|
return x
|
||||||
|
|
Loading…
Add table
Reference in a new issue