kernels.py no longer uses global mesh

This commit is contained in:
Wassim KABALAN 2024-10-22 11:05:21 -04:00
parent 591ee32c55
commit a5b267bd63

View file

@ -3,7 +3,6 @@ from enum import Enum
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from jax._src import mesh as mesh_lib
from jax.lib.xla_client import FftType from jax.lib.xla_client import FftType
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs from jaxdecomp import fftfreq3d, get_output_specs
@ -11,7 +10,6 @@ from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
def fftk(k_array): def fftk(k_array):
""" """
Generate Fourier transform wave numbers for a given mesh. Generate Fourier transform wave numbers for a given mesh.
@ -28,15 +26,19 @@ def fftk(k_array):
return kx, ky, kz return kx, ky, kz
def interpolate_power_spectrum(input, k, pk): def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape) ).reshape(x.shape)
specs = P('x', 'y')
mesh = mesh_lib.thread_resources.env.physical_mesh
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh))
return autoshmap(pk_fn, in_specs=out_specs, out_specs=out_specs)(input) gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
return autoshmap(pk_fn,
gpu_mesh=gpu_mesh,
in_specs=out_specs,
out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1): def gradient_kernel(kvec, direction, order=1):