kernels.py no longer uses global mesh

This commit is contained in:
Wassim KABALAN 2024-10-22 11:05:21 -04:00
parent 591ee32c55
commit a5b267bd63

View file

@ -3,7 +3,6 @@ from enum import Enum
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax._src import mesh as mesh_lib
from jax.lib.xla_client import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap
def fftk(k_array):
"""
Generate Fourier transform wave numbers for a given mesh.
@ -28,31 +26,35 @@ def fftk(k_array):
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk):
def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
specs = P('x', 'y')
mesh = mesh_lib.thread_resources.env.physical_mesh
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
return autoshmap(pk_fn,
gpu_mesh=gpu_mesh,
in_specs=out_specs,
out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
Parameters:
-----------
kvec: array
Array of k values in Fourier space
direction: int
Index of the direction in which to take the gradient
Returns:
--------
wts: array
Complex kernel
"""
Computes the gradient kernel in the requested direction
Parameters:
-----------
kvec: array
Array of k values in Fourier space
direction: int
Index of the direction in which to take the gradient
Returns:
--------
wts: array
Complex kernel
"""
if order == 0:
wts = 1j * kvec[direction]
wts = jnp.squeeze(wts)
@ -68,16 +70,16 @@ def gradient_kernel(kvec, direction, order=1):
def laplace_kernel(kvec):
"""
Compute the Laplace kernel from a given K vector
Parameters:
-----------
kvec: array
Array of k values in Fourier space
Returns:
--------
wts: array
Complex kernel
"""
Compute the Laplace kernel from a given K vector
Parameters:
-----------
kvec: array
Array of k values in Fourier space
Returns:
--------
wts: array
Complex kernel
"""
kk = sum(ki**2 for ki in kvec)
wts = jnp.where(kk == 0, 1., 1. / kk)
return wts