Wrap interpolation function to avoid all gather

This commit is contained in:
Wassim KABALAN 2024-07-18 12:43:52 +02:00
parent 7f48cfa8af
commit c81d4d2336

View file

@ -1,6 +1,7 @@
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax.sharding import PartitionSpec as P
@ -32,6 +33,13 @@ def fftk(shape, dtype=np.float32):
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction