# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
# Synced 2025-02-22 09:37:11 +00:00
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
from jax_cosmo.scipy.ode import odeint


def E(cosmo, a):
    r"""Scale factor dependent factor E(a) in the Hubble
    parameter.

    Parameters
    ----------
    cosmo: Cosmology
        Cosmological parameters structure

    a : array_like
        Scale factor

    Returns
    -------
    E : ndarray, or float if input scalar
        Scaling of the Hubble constant as a function of
        scale factor

    Notes
    -----
    The Hubble parameter at scale factor `a` is given by
    :math:`H^2(a) = E^2(a) H_0^2` where :math:`E^2` is obtained through
    Friedmann's Equation (see :cite:`2005:Percival`) :

    .. math::

        E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)}

    where :math:`f(a)` is the Dark Energy evolution parameter computed
    by :py:meth:`.f_de`.
    """
    return np.power(Esqr(cosmo, a), 0.5)


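# Illustrative sketch, not part of JaxPM: a plain-NumPy evaluation of E(a)
# from the Friedmann equation in the docstring above, for a w0-wa dark energy
# model. It assumes jax_cosmo's convention
#     f_de(a) = -3 (1 + w0 + wa) + 3 wa (a - 1) / ln(a)
# and closes the energy budget with Omega_k = 1 - Omega_m - Omega_de
# (radiation neglected). The helper name and its arguments are hypothetical.
def _sketch_E(a, Omega_m, Omega_de, w0=-1.0, wa=0.0):
    import numpy as onp

    a = onp.asarray(a, dtype=float)
    Omega_k = 1.0 - Omega_m - Omega_de
    log_a = onp.log(a)
    # (a - 1) / ln(a) -> 1 as a -> 1; substitute the limit to avoid 0 / 0.
    near_one = onp.abs(log_a) < 1e-12
    ratio = onp.where(near_one, 1.0,
                      (a - 1.0) / onp.where(near_one, 1.0, log_a))
    f_de = -3.0 * (1.0 + w0 + wa) + 3.0 * wa * ratio
    Esqr = Omega_m * a**-3 + Omega_k * a**-2 + Omega_de * a**f_de
    return onp.sqrt(Esqr)


# Usage: _sketch_E(0.5, Omega_m=0.3, Omega_de=0.7) for flat LCDM. In an
# Einstein-de Sitter universe (Omega_m = 1, Omega_de = 0), E(a) = a^{-3/2}.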
def df_de(cosmo, a, epsilon=1e-5):
    r"""Derivative of the evolution parameter for the Dark Energy density
    f(a) with respect to the scale factor.

    Parameters
    ----------
    cosmo: Cosmology
        Cosmological parameters structure

    a : array_like
        Scale factor

    epsilon : float
        Small offset subtracted from `a` inside the logarithm, so that
        ln(a) is never exactly 0 at a = 1 and the division stays finite

    Returns
    -------
    df(a)/da : ndarray, or float if input scalar
        Derivative of the evolution parameter for the Dark Energy density
        with respect to the scale factor.

    Notes
    -----
    The expression for :math:`\frac{df(a)}{da}` is:

    .. math::

        \frac{df}{da}(a) = \frac{3w_a \left( \ln(a-\epsilon) -
        \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)}
    """
    return (3 * cosmo.wa * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) /
            np.power(np.log(a - epsilon), 2))


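# Illustrative sketch, not part of JaxPM: compare the closed-form df/da used
# by df_de with a central finite difference of f(a), in plain NumPy. It
# assumes jax_cosmo's convention
#     f_de(a) = -3 (1 + w0 + wa) + 3 wa (a - 1) / ln(a),
# whose w0 term is constant and drops out of the derivative. The helper name
# and default parameter values are hypothetical.
def _sketch_df_de_check(a, w0=-0.9, wa=0.3, epsilon=1e-5, h=1e-6):
    import numpy as onp

    def f_de(x):
        return -3.0 * (1.0 + w0 + wa) + 3.0 * wa * (x - 1.0) / onp.log(x)

    # Same expression as df_de above, including the epsilon shift.
    log_a = onp.log(a - epsilon)
    closed_form = 3.0 * wa * (log_a - (a - 1.0) / (a - epsilon)) / log_a**2
    finite_diff = (f_de(a + h) - f_de(a - h)) / (2.0 * h)
    return closed_form, finite_diff


# Usage: away from a = 1 the two values agree closely, e.g.
# _sketch_df_de_check(0.5); the epsilon shift only adds a small bias.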
def dEa(cosmo, a):
    r"""Derivative of the scale factor dependent factor E(a) in the Hubble
    parameter.

    Parameters
    ----------
    cosmo: Cosmology
        Cosmological parameters structure

    a : array_like
        Scale factor

    Returns
    -------
    dE(a)/da : ndarray, or float if input scalar
        Derivative of the scale factor dependent factor in the Hubble
        parameter with respect to the scale factor.

    Notes
    -----
    The Hubble parameter at scale factor `a` is given by
    :math:`H^2(a) = E^2(a) H_0^2` where :math:`E^2` is obtained through
    Friedmann's Equation (see :cite:`2005:Percival`) :

    .. math::

        E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)}

    where :math:`f(a)` is the Dark Energy evolution parameter computed
    by :py:meth:`.f_de`. Since
    :math:`\frac{d}{da} a^{f(a)} = a^{f(a)} \left( f'(a) \ln a + \frac{f(a)}{a} \right)`,
    the expression for :math:`\frac{dE}{da}` is:

    .. math::

        \frac{dE(a)}{da} = \frac{-3a^{-4}\Omega_m - 2a^{-3}\Omega_k
        + \Omega_{de} a^{f(a)} \left( f'(a) \ln a + \frac{f(a)}{a} \right)}{2E(a)}
    """
    f = f_de(cosmo, a)
    # Chain rule for the Dark Energy term: d/da a^f(a) = a^f(a) (f' ln a + f / a)
    dEsqr_da = (-3 * cosmo.Omega_m * np.power(a, -4) -
                2 * cosmo.Omega_k * np.power(a, -3) +
                cosmo.Omega_de * np.power(a, f) *
                (df_de(cosmo, a) * np.log(a) + f / a))
    return 0.5 * dEsqr_da / np.power(Esqr(cosmo, a), 0.5)


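# Illustrative sketch, not part of JaxPM: a finite-difference check of dE/da
# for a flat w0-wa model in plain NumPy, assuming jax_cosmo's f_de
# convention. The Dark Energy term is differentiated with the full chain rule,
#     d/da a^{f(a)} = a^{f(a)} (f'(a) ln a + f(a) / a),
# which matters whenever w0 != -1 or wa != 0. The helper name and default
# parameter values are hypothetical.
def _sketch_dEa_check(a, Omega_m=0.3, w0=-0.8, wa=0.2, h=1e-6):
    import numpy as onp

    Omega_de = 1.0 - Omega_m  # flat: Omega_k = 0

    def f_de(x):
        return -3.0 * (1.0 + w0 + wa) + 3.0 * wa * (x - 1.0) / onp.log(x)

    def E(x):
        return onp.sqrt(Omega_m * x**-3 + Omega_de * x**f_de(x))

    log_a = onp.log(a)
    df = 3.0 * wa * (log_a - (a - 1.0) / a) / log_a**2
    f = f_de(a)
    dEsqr = (-3.0 * Omega_m * a**-4
             + Omega_de * a**f * (df * log_a + f / a))
    analytic = 0.5 * dEsqr / E(a)
    numeric = (E(a + h) - E(a - h)) / (2.0 * h)
    return analytic, numeric


# Usage: _sketch_dEa_check(0.5) returns two nearly identical values; the
# wa = 0 case (constant w != -1) exercises the f(a) / a term on its own.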
def growth_factor(cosmo, a):
    r"""Compute linear growth factor D(a) at a given scale factor,
    normalized such that D(a=1) = 1.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    D: ndarray, or float if input scalar
        Growth factor computed at requested scale factor

    Notes
    -----
    The growth computation will depend on the cosmology parametrization, for
    instance if the :math:`\gamma` parameter is defined, the growth will be
    computed assuming the :math:`f = \Omega^\gamma` growth rate, otherwise the
    usual ODE for growth will be solved.
    """
    if cosmo._flags["gamma_growth"]:
        return _growth_factor_gamma(cosmo, a)
    else:
        return _growth_factor_ODE(cosmo, a)


def growth_factor_second(cosmo, a):
    r"""Compute second order growth factor D2(a) at a given scale factor,
    normalized such that D2(a=1) = 1.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    D2: ndarray, or float if input scalar
        Second order growth factor computed at requested scale factor

    Notes
    -----
    The growth computation will depend on the cosmology parametrization,
    as for the linear growth. Currently the second order growth
    factor is not implemented with the :math:`\gamma` parameter.
    """
    if cosmo._flags["gamma_growth"]:
        raise NotImplementedError(
            "Gamma growth factor is not implemented for second order growth!")
    else:
        return _growth_factor_second_ODE(cosmo, a)


def growth_rate(cosmo, a):
    r"""Compute growth rate f = dlnD/dlna at a given scale factor.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    f: ndarray, or float if input scalar
        Growth rate computed at requested scale factor

    Notes
    -----
    The growth computation will depend on the cosmology parametrization, for
    instance if the :math:`\gamma` parameter is defined, the growth will be
    computed assuming the :math:`f = \Omega^\gamma` growth rate, otherwise the
    usual ODE for growth will be solved.

    The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by:

    .. math::

        f_{\gamma}(a) = \Omega_m^{\gamma} (a)

    with :math:`\gamma` in LCDM, given approximately by:

    .. math::

        \gamma = 0.55

    see :cite:`2019:Euclid Preparation VII, eqn.32`
    """
    if cosmo._flags["gamma_growth"]:
        return _growth_rate_gamma(cosmo, a)
    else:
        return _growth_rate_ODE(cosmo, a)


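# Illustrative sketch, not part of JaxPM: the gamma approximation
# f(a) = Omega_m(a)**gamma from the Notes above, evaluated in plain NumPy for
# flat LCDM, where Omega_m(a) = Omega_m a^-3 / E^2(a) and
# E^2(a) = Omega_m a^-3 + (1 - Omega_m). The helper name and defaults are
# hypothetical.
def _sketch_growth_rate_gamma(a, Omega_m=0.3, gamma=0.55):
    import numpy as onp

    a = onp.asarray(a, dtype=float)
    Esqr = Omega_m * a**-3 + (1.0 - Omega_m)
    Omega_m_a = Omega_m * a**-3 / Esqr
    return Omega_m_a**gamma


# Usage: at a = 1 this reduces to Omega_m**gamma (about 0.516 for
# Omega_m = 0.3), and f -> 1 deep in matter domination.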
def growth_rate_second(cosmo, a):
    r"""Compute second order growth rate f2 = dlnD2/dlna at a given scale
    factor.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    f2: ndarray, or float if input scalar
        Second order growth rate computed at requested scale factor

    Notes
    -----
    The growth computation will depend on the cosmology parametrization,
    as for the linear growth rate. Currently the second order growth
    rate is not implemented with the :math:`\gamma` parameter.
    """
    if cosmo._flags["gamma_growth"]:
        raise NotImplementedError(
            "Gamma growth rate is not implemented for second order growth!")
    else:
        return _growth_rate_second_ODE(cosmo, a)


def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4):
    """Compute linear growth factor D(a) at a given scale factor,
    normalised such that D(a=1) = 1.

    On the first call, the first and second order growth factors, growth
    rates and second derivatives are tabulated on a log-spaced grid and
    cached in ``cosmo._workspace["background.growth_factor"]``; later calls
    only interpolate.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    log10_amin: float
        Base-10 logarithm of the minimum tabulated scale factor,
        default -3 (i.e. a_min = 1e-3)

    steps: int
        Number of log-spaced points in the table, default 128

    eps: float
        Currently unused

    Returns
    -------
    D: ndarray, or float if input scalar
        Growth factor computed at requested scale factor
    """
    # Check if growth has already been computed
    if "background.growth_factor" not in cosmo._workspace.keys():
        # Compute tabulated array
        atab = np.logspace(log10_amin, 0.0, steps)

        # State y = [[D1, D2], [dD1/da, dD2/da]], x is the scale factor
        def D_derivs(y, x):
            q = (2.0 - 0.5 *
                 (Omega_m_a(cosmo, x) +
                  (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
            r = 1.5 * Omega_m_a(cosmo, x) / x / x

            g1, g2 = y[0]
            f1, f2 = y[1]
            dy1da = [f1, -q * f1 + r * g1]
            dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
            return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])

        # Growing-mode initial conditions: D1 = a, D2 = -3/7 a^2
        y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
                       [1.0, -6.0 / 7 * atab[0]]])
        y = odeint(D_derivs, y0, atab)

        # Compute second derivatives d2D/da2 of the growth factors
        dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
        dyda2 = np.transpose(dyda2, (2, 0, 1))

        # Normalize results
        y1 = y[:, 0, 0]
        gtab = y1 / y1[-1]
        y2 = y[:, 0, 1]
        g2tab = y2 / y2[-1]
        # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
        ftab = y[:, 1, 0] / y1[-1] * atab / gtab
        f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
        # Similarly for second order derivatives
        # Note: these factors are not accessible as parent functions yet
        # since it is unclear what to refer to them with.
        htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
        h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab

        cache = {
            "a": atab,
            "g": gtab,
            "f": ftab,
            "h": htab,
            "g2": g2tab,
            "f2": f2tab,
            "h2": h2tab,
        }
        cosmo._workspace["background.growth_factor"] = cache
    else:
        cache = cosmo._workspace["background.growth_factor"]
    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)


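# Illustrative sketch, not part of JaxPM: the coupled first and second order
# growth system integrated by _growth_factor_ODE, reproduced in plain NumPy
# for flat LCDM (w = -1). A hand-written classic RK4 step stands in for
# jax_cosmo's odeint, so results may differ slightly. The state layout matches
# D_derivs: y = [[D1, D2], [dD1/da, dD2/da]], started from the growing-mode
# values D1 = a, D2 = -3/7 a^2. The helper name and defaults are hypothetical.
def _sketch_growth_tables(Omega_m=0.3, log10_amin=-3, steps=128):
    import numpy as onp

    def derivs(y, a):
        Esqr = Omega_m * a**-3 + (1.0 - Omega_m)
        Omega_m_a = Omega_m * a**-3 / Esqr
        Omega_de_a = 1.0 - Omega_m_a  # flat universe
        w = -1.0
        q = (2.0 - 0.5 * (Omega_m_a + (1.0 + 3.0 * w) * Omega_de_a)) / a
        r = 1.5 * Omega_m_a / a**2
        (g1, g2), (f1, f2) = y
        return onp.array([[f1, f2],
                          [-q * f1 + r * g1, -q * f2 + r * g2 - r * g1**2]])

    atab = onp.logspace(log10_amin, 0.0, steps)
    y = onp.empty((steps, 2, 2))
    a0 = atab[0]
    y[0] = [[a0, -3.0 / 7.0 * a0**2], [1.0, -6.0 / 7.0 * a0]]
    for i in range(steps - 1):
        a, h = atab[i], atab[i + 1] - atab[i]
        k1 = derivs(y[i], a)
        k2 = derivs(y[i] + 0.5 * h * k1, a + 0.5 * h)
        k3 = derivs(y[i] + 0.5 * h * k2, a + 0.5 * h)
        k4 = derivs(y[i] + h * k3, a + h)
        y[i + 1] = y[i] + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

    D1, D2 = y[:, 0, 0], y[:, 0, 1]
    # Growth rates f = dlnD/dlna = (a / D) dD/da; factors normalised at a = 1.
    f1 = atab * y[:, 1, 0] / D1
    f2 = atab * y[:, 1, 1] / D2
    return atab, D1 / D1[-1], D2 / D2[-1], f1, f2


# Usage: for Omega_m = 1 (Einstein-de Sitter) the exact solutions are D1 = a
# and D2 = -3/7 a^2, so the normalised tables approach a and a^2, with
# f1 = 1 and f2 = 2.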
def _growth_rate_ODE(cosmo, a):
    """Compute growth rate f = dlnD/dlna at a given scale factor by solving
    the linear growth ODE.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    f: ndarray, or float if input scalar
        Growth rate computed at requested scale factor
    """
    # Check if growth has already been computed, if not, compute it
    if "background.growth_factor" not in cosmo._workspace.keys():
        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
    cache = cosmo._workspace["background.growth_factor"]
    return interp(a, cache["a"], cache["f"])


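# Illustrative sketch, not part of JaxPM: the "tabulate once, then
# interpolate" pattern used by _growth_rate_ODE, reduced to plain NumPy.
# A dict stands in for cosmo._workspace, and a toy table (f = 1 on a
# log-spaced grid, as in an Einstein-de Sitter universe) replaces the ODE
# solve; both are assumptions for illustration, as is the helper name.
def _sketch_cached_growth_rate(workspace, a):
    import numpy as onp

    key = "background.growth_factor"
    if key not in workspace:
        # Expensive step, run only on the first call for this workspace.
        atab = onp.logspace(-3, 0, 128)
        workspace[key] = {"a": atab, "f": onp.ones_like(atab)}
    cache = workspace[key]
    # Linear interpolation; values outside the table are clamped to the ends.
    return onp.interp(a, cache["a"], cache["f"])


# Usage: ws = {}; _sketch_cached_growth_rate(ws, 0.5) fills ws once, and
# later calls with the same ws reuse the stored table.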
def _growth_factor_second_ODE(cosmo, a):
    """Compute second order growth factor D2(a) at a given scale factor,
    normalised such that D2(a=1) = 1.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    D2: ndarray, or float if input scalar
        Second order growth factor computed at requested scale factor
    """
    # Check if growth has already been computed, if not, compute it
    if "background.growth_factor" not in cosmo._workspace.keys():
        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
    cache = cosmo._workspace["background.growth_factor"]
    return interp(a, cache["a"], cache["g2"])


def _growth_rate_second_ODE(cosmo, a):
    """Compute second order growth rate f2 = dlnD2/dlna at a given scale
    factor by solving the coupled first and second order growth ODEs.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    Returns
    -------
    f2: ndarray, or float if input scalar
        Second order growth rate computed at requested scale factor
    """
    # Check if growth has already been computed, if not, compute it
    if "background.growth_factor" not in cosmo._workspace.keys():
        _growth_factor_ODE(cosmo, np.atleast_1d(1.0))
    cache = cosmo._workspace["background.growth_factor"]
    return interp(a, cache["a"], cache["f2"])


def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
    r"""Computes growth factor by integrating the growth rate provided by the
    :math:`\gamma` parametrization. Normalized such that D(a=1) = 1.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a: array_like
        Scale factor

    log10_amin: float
        Base-10 logarithm of the minimum tabulated scale factor,
        default -3 (i.e. a_min = 1e-3)

    steps: int
        Number of log-spaced points in the table, default 128

    Returns
    -------
    D: ndarray, or float if input scalar
        Growth factor computed at requested scale factor
    """
    # Check if growth has already been computed, if not, compute it
    if "background.growth_factor" not in cosmo._workspace.keys():
        # Compute tabulated array
        atab = np.logspace(log10_amin, 0.0, steps)

        # dlnD/dlna = f(a), integrated in ln a from ln D(a_min) = ln a_min
        def integrand(y, loga):
            xa = np.exp(loga)
            return _growth_rate_gamma(cosmo, xa)

        gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
        gtab = gtab / gtab[-1]  # Normalize to a=1.
        cache = {"a": atab, "g": gtab}
        cosmo._workspace["background.growth_factor"] = cache
    else:
        cache = cosmo._workspace["background.growth_factor"]
    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)


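# Illustrative sketch, not part of JaxPM: _growth_factor_gamma integrates
# dlnD/dlna = f(a) with f = Omega_m(a)**gamma, starting from D(a_min) = a_min.
# Here the same integral is taken with the cumulative trapezoidal rule in
# plain NumPy for flat LCDM (a swap for jax_cosmo's odeint), then normalised
# to D(1) = 1. The helper name and defaults are hypothetical.
def _sketch_growth_factor_gamma(Omega_m=0.3, gamma=0.55, log10_amin=-3,
                                steps=128):
    import numpy as onp

    atab = onp.logspace(log10_amin, 0.0, steps)
    loga = onp.log(atab)
    Omega_m_a = Omega_m * atab**-3 / (Omega_m * atab**-3 + 1.0 - Omega_m)
    f = Omega_m_a**gamma
    # ln D(a_i) = ln a_min + integral of f dln a, one trapezoid per interval.
    increments = 0.5 * (f[1:] + f[:-1]) * onp.diff(loga)
    lnD = loga[0] + onp.concatenate([[0.0], onp.cumsum(increments)])
    D = onp.exp(lnD)
    return atab, D / D[-1]


# Usage: for Omega_m = 1 the integrand is 1 and D(a) = a; for Omega_m < 1 the
# late-time suppression (f < 1) gives D(a) > a for a < 1.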
def _growth_rate_gamma(cosmo, a):
    r"""Growth rate approximation at scale factor `a`.

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object

    a : array_like
        Scale factor

    Returns
    -------
    f_gamma : ndarray, or float if input scalar
        Growth rate approximation at the requested scale factor

    Notes
    -----
    The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by:

    .. math::

        f_{\gamma}(a) = \Omega_m^{\gamma} (a)

    with :math:`\gamma` in LCDM, given approximately by:

    .. math::

        \gamma = 0.55

    see :cite:`2019:Euclid Preparation VII, eqn.32`
    """
    return Omega_m_a(cosmo, a)**cosmo.gamma


def Gf(cosmo, a):
    r"""
    FastPM growth factor function

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object.

    a : array_like
        Scale factor.

    Returns
    -------
    Gf : ndarray, or float if input scalar
        FastPM growth factor function.

    Notes
    -----
    The expression for :math:`Gf(a)` is:

    .. math::

        Gf(a) = D'_{1,\mathrm{norm}}(a)\, a^3 E(a)

    where the prime denotes a derivative with respect to :math:`a`,
    evaluated as :math:`D'_1 = f_1 D_1 / a`.
    """
    f1 = growth_rate(cosmo, a)
    g1 = growth_factor(cosmo, a)
    D1f = f1 * g1 / a  # dD1/da from the growth rate f1 = dlnD1/dlna
    return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)


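# Illustrative sketch, not part of JaxPM: Gf(a) = D'(a) a^3 E(a), where, as in
# Gf above, D' = dD/da is recovered from the growth rate as f D / a. The
# hypothetical helper takes D, f and E as precomputed scalars or NumPy arrays.
# In an Einstein-de Sitter universe D = a, f = 1 and E = a^{-3/2}, so
# Gf(a) = a^{3/2}.
def _sketch_Gf(a, D, f, E):
    dD_da = f * D / a  # dD/da from f = dlnD/dlna
    return dD_da * a**3 * E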
def Gf2(cosmo, a):
    r""" FastPM second order growth factor function

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object.

    a : array_like
        Scale factor.

    Returns
    -------
    Gf2 : ndarray, or float if input scalar
        FastPM second order growth factor function.

    Notes
    -----
    The expression for :math:`Gf_2(a)` is:

    .. math::

        Gf_2(a) = D'_{2,\mathrm{norm}}(a)\, a^3 E(a)

    where the prime denotes a derivative with respect to :math:`a`,
    evaluated as :math:`D'_2 = f_2 D_2 / a`.
    """
    f2 = growth_rate_second(cosmo, a)
    g2 = growth_factor_second(cosmo, a)
    D2f = f2 * g2 / a  # dD2/da from the growth rate f2 = dlnD2/dlna
    return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5)


def dGfa(cosmo, a):
    r""" Derivative of Gf with respect to a

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object.

    a : array_like
        Scale factor.

    Returns
    -------
    gf : ndarray, or float if input scalar
        The derivative of Gf with respect to a.

    Notes
    -----
    The expression for :math:`gf(a)` is:

    .. math::

        gf(a) = \frac{dGf}{da} = D''_{1,\mathrm{norm}}\, a^3 E(a)
        + D'_{1,\mathrm{norm}}\, a^3 E'(a) + 3 a^2 E(a) D'_{1,\mathrm{norm}}
    """
    f1 = growth_rate(cosmo, a)
    g1 = growth_factor(cosmo, a)
    D1f = f1 * g1 / a
    # d2D1/da2 = h * D1 / a, using the table h = a D1'' / D1 that
    # _growth_factor_ODE stores in the workspace cache
    cache = cosmo._workspace['background.growth_factor']
    f1p = cache['h'] / cache['a'] * cache['g']
    f1p = interp(np.log(a), np.log(cache['a']), f1p)
    Ea = E(cosmo, a)
    return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + 3 * a**2 * Ea * D1f)


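# Illustrative sketch, not part of JaxPM: check the product rule behind dGfa,
#     dGf/da = D'' a^3 E + D' a^3 E' + 3 a^2 E D',
# with a central finite difference of Gf = D' a^3 E in plain NumPy. A toy
# growth factor D(a) = a**p and a flat LCDM E(a) are assumed so every
# derivative is analytic; the helper name and defaults are hypothetical.
def _sketch_dGfa_check(a, p=0.9, Omega_m=0.3, h=1e-6):
    import numpy as onp

    def E(x):
        return onp.sqrt(Omega_m * x**-3 + 1.0 - Omega_m)

    def Gf(x):
        return p * x**(p - 1.0) * x**3 * E(x)  # D'(x) x^3 E(x)

    dD = p * a**(p - 1.0)
    d2D = p * (p - 1.0) * a**(p - 2.0)
    dE = -1.5 * Omega_m * a**-4 / E(a)
    analytic = d2D * a**3 * E(a) + dD * a**3 * dE + 3.0 * a**2 * E(a) * dD
    numeric = (Gf(a + h) - Gf(a - h)) / (2.0 * h)
    return analytic, numeric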
def dGf2a(cosmo, a):
    r""" Derivative of Gf2 with respect to a

    Parameters
    ----------
    cosmo: `Cosmology`
        Cosmology object.

    a : array_like
        Scale factor.

    Returns
    -------
    gf2 : ndarray, or float if input scalar
        The derivative of Gf2 with respect to a.

    Notes
    -----
    The expression for :math:`gf_2(a)` is:

    .. math::

        gf_2(a) = \frac{dGf_2}{da} = D''_{2,\mathrm{norm}}\, a^3 E(a)
        + D'_{2,\mathrm{norm}}\, a^3 E'(a) + 3 a^2 E(a) D'_{2,\mathrm{norm}}
    """
    f2 = growth_rate_second(cosmo, a)
    g2 = growth_factor_second(cosmo, a)
    D2f = f2 * g2 / a
    # d2D2/da2 = h2 * D2 / a, using the table h2 = a D2'' / D2 that
    # _growth_factor_ODE stores in the workspace cache
    cache = cosmo._workspace['background.growth_factor']
    f2p = cache['h2'] / cache['a'] * cache['g2']
    f2p = interp(np.log(a), np.log(cache['a']), f2p)
    E_a = E(cosmo, a)
    return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
            3 * a**2 * E_a * D2f)