Return normal order frequencies for single GPU

This commit is contained in:
Wassim KABALAN 2024-07-18 12:44:18 +02:00
parent c81d4d2336
commit abde5439f6

View file

@ -3,6 +3,7 @@ from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
@ -28,7 +29,12 @@ def fftk(shape, dtype=np.float32):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
if not mesh_lib.thread_resources.env.physical_mesh.empty:
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
else:
kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds
# to the order of dimensions in the transposed FFT
return kx, ky, kz