mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
Return normal order frequencies for single GPU
This commit is contained in:
parent
c81d4d2336
commit
abde5439f6
1 changed files with 7 additions and 1 deletions
|
@ -3,6 +3,7 @@ from functools import partial
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from jax._src import mesh as mesh_lib
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
@ -28,7 +29,12 @@ def fftk(shape, dtype=np.float32):
|
||||||
return (ky.reshape([-1, 1, 1]),
|
return (ky.reshape([-1, 1, 1]),
|
||||||
kz.reshape([1, -1, 1]),
|
kz.reshape([1, -1, 1]),
|
||||||
kx.reshape([1, 1, -1])) # yapf: disable
|
kx.reshape([1, 1, -1])) # yapf: disable
|
||||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
|
||||||
|
if not mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||||
|
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
||||||
|
else:
|
||||||
|
kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds
|
||||||
|
|
||||||
# to the order of dimensions in the transposed FFT
|
# to the order of dimensions in the transposed FFT
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue