From 01b952701edc16baa5f7becfbdb624cd764e1375 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 21 Oct 2024 13:55:48 -0400 Subject: [PATCH] update for latest jaxDecomp --- jaxpm/kernels.py | 64 ++++++++---------------------------------------- jaxpm/pm.py | 11 +++++---- 2 files changed, 16 insertions(+), 59 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index d954132..fabe3a2 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,40 +1,18 @@ from enum import Enum -from functools import partial import jax.numpy as jnp import jax_cosmo as jc import numpy as np from jax._src import mesh as mesh_lib +from jax.lib.xla_client import FftType from jax.sharding import PartitionSpec as P +from jaxdecomp import fftfreq3d, get_output_specs from jaxpm.distributed import autoshmap -class PencilType(Enum): - NO_DECOMP = 0 - SLAB_XY = 1 - SLAB_YZ = 2 - PENCILS = 3 - -def get_pencil_type(): - mesh = mesh_lib.thread_resources.env.physical_mesh - if mesh.empty: - pdims = None - else: - pdims = mesh.devices.shape[::-1] - - if pdims == (1, 1) or pdims == None: - return PencilType.NO_DECOMP - elif pdims[0] == 1: - return PencilType.SLAB_XY - elif pdims[1] == 1: - return PencilType.SLAB_YZ - else: - return PencilType.PENCILS - - -def fftk(shape, dtype=np.float32): +def fftk(k_array): """ Generate Fourier transform wave numbers for a given mesh. @@ -44,31 +22,8 @@ def fftk(shape, dtype=np.float32): Returns: list: List of wave number arrays for each dimension in the order [kx, ky, kz]. - """ - kx, ky, kz = [jnp.fft.fftfreq(s, dtype=dtype) * 2 * np.pi for s in shape] - - @partial(autoshmap, - in_specs=(P('x'), P('y'), P(None)), - out_specs=(P('x'), P(None, 'y'), P(None)), - in_fourrier_space=True) - def get_kvec(ky, kz, kx): - return (ky.reshape([-1, 1, 1]), - kz.reshape([1, -1, 1]), - kx.reshape([1, 1, -1])) # yapf: disable - - pencil_type = get_pencil_type() - # YZ returns Y pencil - # XY and pencils returns a Z pencil - # NO_DECOMP returns a X pencil - if pencil_type == PencilType.NO_DECOMP: - kx, ky, kz = get_kvec(kx, ky, kz) # Z Y X ==> X pencil - elif pencil_type == PencilType.SLAB_YZ: - kz, kx, ky = get_kvec(kz, kx, ky) # X Z Y ==> Y pencil - elif pencil_type == PencilType.SLAB_XY or pencil_type == PencilType.PENCILS: - ky, kz, kx = get_kvec(ky, kz, kx) # Z X Y ==> Z pencil - else: - raise ValueError("Unknown pencil type") - + """ + kx, ky, kz = fftfreq3d(k_array) # to the order of dimensions in the transposed FFT return kx, ky, kz @@ -77,10 +32,11 @@ def interpolate_power_spectrum(input, k, pk): pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk ).reshape(x.shape) - return autoshmap(pk_fn, - in_specs=P('x', 'y'), - out_specs=P('x', 'y'), - in_fourrier_space=True)(input) + specs = P('x', 'y') + mesh = mesh_lib.thread_resources.env.physical_mesh + out_specs = P(*get_output_specs(FftType.FFT, specs, mesh)) + + return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input) def gradient_kernel(kvec, direction, order=1): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 377df8e..090bd2b 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -22,13 +22,13 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): assert (delta is not None ), "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape - kvec = fftk(mesh_shape) if delta is None: delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) else: delta_k = fft3d(delta) + kvec = fftk(delta_k) # Computes gravitational potential pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) @@ -137,15 +137,16 @@ def linear_field(mesh_shape, box_size, pk, seed): """ Generate initial conditions. """ - kvec = fftk(mesh_shape) + # Initialize a random field with one slice on each gpu + field = normal_field(mesh_shape, seed=seed) + field = fft3d(field) + kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / ( box_size[0] * box_size[1] * box_size[2]) - # Initialize a random field with one slice on each gpu - field = normal_field(mesh_shape, seed=seed) - field = fft3d(field) * pkmesh**0.5 + field = field * (pkmesh)**0.5 field = ifft3d(field) return field