diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 97ac39e..176e1c3 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -3,6 +3,7 @@ from functools import partial import jax.numpy as jnp import jax_cosmo as jc import numpy as np +from jax._src import mesh as mesh_lib from jax.sharding import PartitionSpec as P from jaxpm.distributed import autoshmap @@ -28,7 +29,12 @@ def fftk(shape, dtype=np.float32): return (ky.reshape([-1, 1, 1]), kz.reshape([1, -1, 1]), kx.reshape([1, 1, -1])) # yapf: disable - ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds + + if not mesh_lib.thread_resources.env.physical_mesh.empty: + ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds + else: + kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds + # to the order of dimensions in the transposed FFT return kx, ky, kz