mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
Return normal order frequencies for single GPU
This commit is contained in:
parent
c81d4d2336
commit
abde5439f6
1 changed files with 7 additions and 1 deletions
|
@ -3,6 +3,7 @@ from functools import partial
|
|||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
|
@ -28,7 +29,12 @@ def fftk(shape, dtype=np.float32):
|
|||
return (ky.reshape([-1, 1, 1]),
|
||||
kz.reshape([1, -1, 1]),
|
||||
kx.reshape([1, 1, -1])) # yapf: disable
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||
|
||||
if not mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||
else:
|
||||
kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds
|
||||
|
||||
# to the order of dimensions in the transposed FFT
|
||||
return kx, ky, kz
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue