mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
Adjust pencil type for frequencies
This commit is contained in:
parent
1f2035176f
commit
f25eb7d465
2 changed files with 38 additions and 3 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -98,6 +98,11 @@ __pypackages__/
|
||||||
celerybeat-schedule
|
celerybeat-schedule
|
||||||
celerybeat.pid
|
celerybeat.pid
|
||||||
|
|
||||||
|
|
||||||
|
out
|
||||||
|
traces
|
||||||
|
*.npy
|
||||||
|
*.out
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,29 @@ from jax._src import mesh as mesh_lib
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
from jaxpm.distributed import autoshmap
|
from jaxpm.distributed import autoshmap
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
class PencilType(Enum):
|
||||||
|
NO_DECOMP = 0
|
||||||
|
SLAB_XY = 1
|
||||||
|
SLAB_YZ = 2
|
||||||
|
PENCILS = 3
|
||||||
|
|
||||||
|
def get_pencil_type():
|
||||||
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
|
if mesh.empty:
|
||||||
|
pdims = None
|
||||||
|
else:
|
||||||
|
pdims = mesh.devices.shape[::-1]
|
||||||
|
|
||||||
|
if pdims == (1, 1) or pdims == None:
|
||||||
|
return PencilType.NO_DECOMP
|
||||||
|
elif pdims[0] == 1:
|
||||||
|
return PencilType.SLAB_XY
|
||||||
|
elif pdims[1] == 1:
|
||||||
|
return PencilType.SLAB_YZ
|
||||||
|
else:
|
||||||
|
return PencilType.PENCILS
|
||||||
|
|
||||||
def fftk(shape, dtype=np.float32):
|
def fftk(shape, dtype=np.float32):
|
||||||
"""
|
"""
|
||||||
|
@ -30,10 +52,18 @@ def fftk(shape, dtype=np.float32):
|
||||||
kz.reshape([1, -1, 1]),
|
kz.reshape([1, -1, 1]),
|
||||||
kx.reshape([1, 1, -1])) # yapf: disable
|
kx.reshape([1, 1, -1])) # yapf: disable
|
||||||
|
|
||||||
if not mesh_lib.thread_resources.env.physical_mesh.empty:
|
pencil_type = get_pencil_type()
|
||||||
ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds
|
# YZ returns Y pencil
|
||||||
|
# XY and pencils returns a Z pencil
|
||||||
|
# NO_DECOMP returns a X pencil
|
||||||
|
if pencil_type == PencilType.NO_DECOMP:
|
||||||
|
kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil
|
||||||
|
elif pencil_type == PencilType.SLAB_YZ:
|
||||||
|
kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil
|
||||||
|
elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS:
|
||||||
|
ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil
|
||||||
else:
|
else:
|
||||||
kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds
|
raise ValueError("Unknown pencil type")
|
||||||
|
|
||||||
# to the order of dimensions in the transposed FFT
|
# to the order of dimensions in the transposed FFT
|
||||||
return kx, ky, kz
|
return kx, ky, kz
|
||||||
|
|
Loading…
Add table
Reference in a new issue