From ff8856d2bc04ee17c567bb3c7c16e856470a8506 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sat, 26 Oct 2024 18:53:11 +0200 Subject: [PATCH] use jnp interp instead of jc interp --- jaxpm/kernels.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index d333b8c..235ddec 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -26,8 +26,7 @@ def fftk(k_array): def interpolate_power_spectrum(input, k, pk, sharding=None): - pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk - ).reshape(x.shape) + pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape) gpu_mesh = sharding.mesh if sharding is not None else None specs = sharding.spec if sharding is not None else P()