mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
use jnp interp instead of jc interp
This commit is contained in:
parent
d2f1eb2fa4
commit
ff8856d2bc
1 changed files with 1 additions and 2 deletions
|
@ -26,8 +26,7 @@ def fftk(k_array):
|
|||
|
||||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
specs = sharding.spec if sharding is not None else P()
|
||||
|
|
Loading…
Add table
Reference in a new issue