mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 01:57:10 +00:00
power spec should accept pytrees
This commit is contained in:
parent
38f6599974
commit
151fa09247
1 changed files with 3 additions and 3 deletions
|
@ -4,13 +4,13 @@ import jax.numpy as jnp
|
|||
import numpy as np
|
||||
from jax.scipy.stats import norm
|
||||
from scipy.special import legendre
|
||||
import jax
|
||||
|
||||
__all__ = [
|
||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||
]
|
||||
|
||||
|
||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||
"""
|
||||
Parameters
|
||||
|
@ -100,11 +100,11 @@ def power_spectrum(mesh,
|
|||
n_bins = len(kavg) + 2
|
||||
|
||||
# FFTs
|
||||
meshk = jnp.fft.fftn(mesh, norm='ortho')
|
||||
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
|
||||
if mesh2 is None:
|
||||
mmk = meshk.real**2 + meshk.imag**2
|
||||
else:
|
||||
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
|
||||
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
|
||||
|
||||
# Sum powers
|
||||
pk = jnp.empty((len(poles), n_bins))
|
||||
|
|
Loading…
Add table
Reference in a new issue