diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index d333b8c..235ddec 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -26,8 +26,7 @@ def fftk(k_array): def interpolate_power_spectrum(input, k, pk, sharding=None): - pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk - ).reshape(x.shape) + pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape) gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P()