Adjust pencil type for frequencies

This commit is contained in:
Wassim KABALAN 2024-07-28 13:44:03 +02:00
parent 1f2035176f
commit f25eb7d465
2 changed files with 38 additions and 3 deletions

5
.gitignore vendored
View file

@ -98,6 +98,11 @@ __pypackages__/
celerybeat-schedule
celerybeat.pid
out
traces
*.npy
*.out
# SageMath parsed files
*.sage.py

View file

@ -7,7 +7,29 @@ from jax._src import mesh as mesh_lib
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import autoshmap
from enum import Enum
class PencilType(Enum):
NO_DECOMP = 0
SLAB_XY = 1
SLAB_YZ = 2
PENCILS = 3
def get_pencil_type():
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
pdims = None
else:
pdims = mesh.devices.shape[::-1]
if pdims == (1, 1) or pdims == None:
return PencilType.NO_DECOMP
elif pdims[0] == 1:
return PencilType.SLAB_XY
elif pdims[1] == 1:
return PencilType.SLAB_YZ
else:
return PencilType.PENCILS
def fftk(shape, dtype=np.float32):
"""
@ -30,10 +52,18 @@ def fftk(shape, dtype=np.float32):
kz.reshape([1, -1, 1]),
kx.reshape([1, 1, -1])) # yapf: disable
if not mesh_lib.thread_resources.env.physical_mesh.empty:
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
pencil_type = get_pencil_type()
# YZ returns Y pencil
# XY and pencils returns a Z pencil
# NO_DECOMP returns a X pencil
if pencil_type == PencilType.NO_DECOMP:
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
elif pencil_type == PencilType.SLAB_YZ:
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
else:
kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds
raise ValueError("Unknown pencil type")
# to the order of dimensions in the transposed FFT
return kx, ky, kz