From 151fa09247d41cb6010d66d08b9075b669b511b6 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Mon, 20 Jan 2025 22:40:01 +0100 Subject: [PATCH] power spec should accept pytrees --- jaxpm/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 96faeea..ab34da2 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -4,13 +4,13 @@ import jax.numpy as jnp import numpy as np from jax.scipy.stats import norm from scipy.special import legendre +import jax __all__ = [ 'power_spectrum', 'transfer', 'coherence', 'pktranscoh', 'cross_correlation_coefficients', 'gaussian_smoothing' ] - def _initialize_pk(mesh_shape, box_shape, kedges, los): """ Parameters @@ -100,11 +100,11 @@ def power_spectrum(mesh, n_bins = len(kavg) + 2 # FFTs - meshk = jnp.fft.fftn(mesh, norm='ortho') + meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh) if mesh2 is None: mmk = meshk.real**2 + meshk.imag**2 else: - mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj() + mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2) # Sum powers pk = jnp.empty((len(poles), n_bins))