mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
Wrap interpolation function to avoid all gather
This commit is contained in:
parent
7f48cfa8af
commit
c81d4d2336
1 changed files with 8 additions and 0 deletions
|
@ -1,6 +1,7 @@
|
|||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
import numpy as np
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
|
@ -32,6 +33,13 @@ def fftk(shape, dtype=np.float32):
|
|||
return kx, ky, kz
|
||||
|
||||
|
||||
def interpolate_power_spectrum(input, k, pk):
|
||||
|
||||
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
|
||||
).reshape(x.shape)
|
||||
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)
|
||||
|
||||
|
||||
def gradient_kernel(kvec, direction, order=1):
|
||||
"""
|
||||
Computes the gradient kernel in the requested direction
|
||||
|
|
Loading…
Add table
Reference in a new issue