diff --git a/.gitignore b/.gitignore index b6e4761..baef139 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,11 @@ __pypackages__/ celerybeat-schedule celerybeat.pid + +out +traces +*.npy +*.out # SageMath parsed files *.sage.py diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 176e1c3..0025fa8 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -7,7 +7,29 @@ from jax._src import mesh as mesh_lib from jax.sharding import PartitionSpec as P from jaxpm.distributed import autoshmap +from enum import Enum +class PencilType(Enum): + NO_DECOMP = 0 + SLAB_XY = 1 + SLAB_YZ = 2 + PENCILS = 3 + +def get_pencil_type(): + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + pdims = None + else: + pdims = mesh.devices.shape[::-1] + + if pdims == (1, 1) or pdims == None: + return PencilType.NO_DECOMP + elif pdims[0] == 1: + return PencilType.SLAB_XY + elif pdims[1] == 1: + return PencilType.SLAB_YZ + else: + return PencilType.PENCILS def fftk(shape, dtype=np.float32): """ @@ -30,10 +52,18 @@ def fftk(shape, dtype=np.float32): kz.reshape([1, -1, 1]), kx.reshape([1, 1, -1])) # yapf: disable - if not mesh_lib.thread_resources.env.physical_mesh.empty: - ky, kz, kx = get_kvec(ky, kz, kx) # The order corresponds + pencil_type = get_pencil_type() + # YZ returns Y pencil + # XY and pencils returns a Z pencil + # NO_DECOMP returns a X pencil + if pencil_type == PencilType.NO_DECOMP: + kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil + elif pencil_type == PencilType.SLAB_YZ: + kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil + elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS: + ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil else: - kx, ky, kz = get_kvec(kx, ky, kz) # The order corresponds + raise ValueError("Unknown pencil type") # to the order of dimensions in the transposed FFT return kx, ky, kz