diff --git a/jaxpm/ops.py b/jaxpm/ops.py index 3054807..70a8e4b 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -1,5 +1,8 @@ +import jax.numpy as jnp +import jax_cosmo as jc import numpy as np -from _src.spmd_config import pm_operators + +from jaxpm._src.spmd_config import pm_operators def fftn(arr): @@ -32,3 +35,7 @@ def fftk(shape, symmetric=True, finite=False, dtype=np.float32): def generate_initial_positions(shape): return pm_operators.generate_initial_positions(shape) + + +def interpolate_ic(kfield, kk, cosmo: jc.Cosmology, box_size): + return pm_operators.interpolate_ic(kfield, kk, cosmo, box_size)