use jnp interp instead of jc interp

This commit is contained in:
Wassim KABALAN 2024-10-26 18:53:11 +02:00
parent d2f1eb2fa4
commit ff8856d2bc

View file

@ -26,8 +26,7 @@ def fftk(k_array):
def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()