diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index a4d83d6..97ac39e 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,6 +1,7 @@ from functools import partial import jax.numpy as jnp +import jax_cosmo as jc import numpy as np from jax.sharding import PartitionSpec as P @@ -32,6 +33,13 @@ def fftk(shape, dtype=np.float32): return kx, ky, kz +def interpolate_power_spectrum(input, k, pk): + + pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk + ).reshape(x.shape) + return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input) + + def gradient_kernel(kvec, direction, order=1): """ Computes the gradient kernel in the requested direction