From a5b267bd63609b9290a5c633baffc0bb177fd431 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 11:05:21 -0400 Subject: [PATCH] kernels.py no longer uses global mesh --- jaxpm/kernels.py | 60 +++++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index fabe3a2..8123ad5 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -3,7 +3,6 @@ from enum import Enum import jax.numpy as jnp import jax_cosmo as jc import numpy as np -from jax._src import mesh as mesh_lib from jax.lib.xla_client import FftType from jax.sharding import PartitionSpec as P from jaxdecomp import fftfreq3d, get_output_specs @@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs from jaxpm.distributed import autoshmap - def fftk(k_array): """ Generate Fourier transform wave numbers for a given mesh. @@ -28,31 +26,35 @@ def fftk(k_array): return kx, ky, kz -def interpolate_power_spectrum(input, k, pk): +def interpolate_power_spectrum(input, k, pk, sharding=None): pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk ).reshape(x.shape) - specs = P('x', 'y') - mesh = mesh_lib.thread_resources.env.physical_mesh - out_specs = P(*get_output_specs(FftType.FFT, specs, mesh)) - return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input) + gpu_mesh = sharding.mesh if sharding is not None else None + specs = sharding.spec if sharding is not None else P() + out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh)) + + return autoshmap(pk_fn, + gpu_mesh=gpu_mesh, + in_specs=out_specs, + out_specs=out_specs)(input) def gradient_kernel(kvec, direction, order=1): """ - Computes the gradient kernel in the requested direction - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - direction: int - Index of the direction in which to take the gradient - Returns: - -------- - wts: array - Complex kernel - """ + Computes the gradient kernel in the requested direction + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + direction: int + Index of the direction in which to take the gradient + Returns: + -------- + wts: array + Complex kernel + """ if order == 0: wts = 1j * kvec[direction] wts = jnp.squeeze(wts) @@ -68,16 +70,16 @@ def gradient_kernel(kvec, direction, order=1): def laplace_kernel(kvec): """ - Compute the Laplace kernel from a given K vector - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - Returns: - -------- - wts: array - Complex kernel - """ + Compute the Laplace kernel from a given K vector + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + Returns: + -------- + wts: array + Complex kernel + """ kk = sum(ki**2 for ki in kvec) wts = jnp.where(kk == 0, 1., 1. / kk) return wts