diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 96faeea..ab34da2 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -4,13 +4,13 @@ import jax.numpy as jnp import numpy as np from jax.scipy.stats import norm from scipy.special import legendre +import jax __all__ = [ 'power_spectrum', 'transfer', 'coherence', 'pktranscoh', 'cross_correlation_coefficients', 'gaussian_smoothing' ] - def _initialize_pk(mesh_shape, box_shape, kedges, los): """ Parameters @@ -100,11 +100,11 @@ def power_spectrum(mesh, n_bins = len(kavg) + 2 # FFTs - meshk = jnp.fft.fftn(mesh, norm='ortho') + meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh) if mesh2 is None: mmk = meshk.real**2 + meshk.imag**2 else: - mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj() + mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2) # Sum powers pk = jnp.empty((len(poles), n_bins))