power spec should accept pytrees

This commit is contained in:
Wassim Kabalan 2025-01-20 22:40:01 +01:00
parent 38f6599974
commit 151fa09247

View file

@ -4,13 +4,13 @@ import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
import jax
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Parameters
@ -100,11 +100,11 @@ def power_spectrum(mesh,
n_bins = len(kavg) + 2
# FFTs
meshk = jnp.fft.fftn(mesh, norm='ortho')
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
# Sum powers
pk = jnp.empty((len(poles), n_bins))