mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-23 10:00:54 +00:00
power spec should accept pytrees
This commit is contained in:
parent
38f6599974
commit
151fa09247
1 changed files with 3 additions and 3 deletions
|
@ -4,13 +4,13 @@ import jax.numpy as jnp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
from scipy.special import legendre
|
from scipy.special import legendre
|
||||||
|
import jax
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -100,11 +100,11 @@ def power_spectrum(mesh,
|
||||||
n_bins = len(kavg) + 2
|
n_bins = len(kavg) + 2
|
||||||
|
|
||||||
# FFTs
|
# FFTs
|
||||||
meshk = jnp.fft.fftn(mesh, norm='ortho')
|
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
|
||||||
if mesh2 is None:
|
if mesh2 is None:
|
||||||
mmk = meshk.real**2 + meshk.imag**2
|
mmk = meshk.real**2 + meshk.imag**2
|
||||||
else:
|
else:
|
||||||
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
|
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
|
||||||
|
|
||||||
# Sum powers
|
# Sum powers
|
||||||
pk = jnp.empty((len(poles), n_bins))
|
pk = jnp.empty((len(poles), n_bins))
|
||||||
|
|
Loading…
Add table
Reference in a new issue