mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-14 09:51:11 +00:00
Allow env variable control of caching in growth
This commit is contained in:
parent
e7112e0c25
commit
41ae41ace3
1 changed files with 61 additions and 49 deletions
110
jaxpm/growth.py
110
jaxpm/growth.py
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.numpy import interp
|
||||
from jax_cosmo.background import *
|
||||
|
@ -243,56 +245,61 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
|
|||
Growth factor computed at requested scale factor
|
||||
"""
|
||||
# Check if growth has already been computed
|
||||
#if not "background.growth_factor" in cosmo._workspace.keys():
|
||||
# Compute tabulated array
|
||||
atab = np.logspace(log10_amin, 0.0, steps)
|
||||
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
|
||||
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
|
||||
):
|
||||
cache = cosmo._workspace["background.growth_factor"]
|
||||
else:
|
||||
# Compute tabulated array
|
||||
atab = np.logspace(log10_amin, 0.0, steps)
|
||||
|
||||
def D_derivs(y, x):
|
||||
q = (2.0 - 0.5 *
|
||||
(Omega_m_a(cosmo, x) +
|
||||
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
|
||||
r = 1.5 * Omega_m_a(cosmo, x) / x / x
|
||||
def D_derivs(y, x):
|
||||
q = (2.0 - 0.5 *
|
||||
(Omega_m_a(cosmo, x) +
|
||||
(1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
|
||||
r = 1.5 * Omega_m_a(cosmo, x) / x / x
|
||||
|
||||
g1, g2 = y[0]
|
||||
f1, f2 = y[1]
|
||||
dy1da = [f1, -q * f1 + r * g1]
|
||||
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
|
||||
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
|
||||
g1, g2 = y[0]
|
||||
f1, f2 = y[1]
|
||||
dy1da = [f1, -q * f1 + r * g1]
|
||||
dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
|
||||
return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])
|
||||
|
||||
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
|
||||
[1.0, -6.0 / 7 * atab[0]]])
|
||||
y = odeint(D_derivs, y0, atab)
|
||||
y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
|
||||
[1.0, -6.0 / 7 * atab[0]]])
|
||||
y = odeint(D_derivs, y0, atab)
|
||||
|
||||
# compute second order derivatives growth
|
||||
dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
|
||||
dyda2 = np.transpose(dyda2, (2, 0, 1))
|
||||
# compute second order derivatives growth
|
||||
dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
|
||||
dyda2 = np.transpose(dyda2, (2, 0, 1))
|
||||
|
||||
# Normalize results
|
||||
y1 = y[:, 0, 0]
|
||||
gtab = y1 / y1[-1]
|
||||
y2 = y[:, 0, 1]
|
||||
g2tab = y2 / y2[-1]
|
||||
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
|
||||
ftab = y[:, 1, 0] / y1[-1] * atab / gtab
|
||||
f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
|
||||
# Similarly for second order derivatives
|
||||
# Note: these factors are not accessible as parent functions yet
|
||||
# since it is unclear what to refer to them with.
|
||||
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
|
||||
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
|
||||
# Normalize results
|
||||
y1 = y[:, 0, 0]
|
||||
gtab = y1 / y1[-1]
|
||||
y2 = y[:, 0, 1]
|
||||
g2tab = y2 / y2[-1]
|
||||
# To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
|
||||
ftab = y[:, 1, 0] / y1[-1] * atab / gtab
|
||||
f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
|
||||
# Similarly for second order derivatives
|
||||
# Note: these factors are not accessible as parent functions yet
|
||||
# since it is unclear what to refer to them with.
|
||||
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
|
||||
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
|
||||
|
||||
cache = {
|
||||
"a": atab,
|
||||
"g": gtab,
|
||||
"f": ftab,
|
||||
"h": htab,
|
||||
"g2": g2tab,
|
||||
"f2": f2tab,
|
||||
"h2": h2tab,
|
||||
}
|
||||
if CACHING_ACTIVATED:
|
||||
cosmo._workspace["background.growth_factor"] = cache
|
||||
|
||||
cache = {
|
||||
"a": atab,
|
||||
"g": gtab,
|
||||
"f": ftab,
|
||||
"h": htab,
|
||||
"g2": g2tab,
|
||||
"f2": f2tab,
|
||||
"h2": h2tab,
|
||||
}
|
||||
|
||||
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) , cache
|
||||
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
|
||||
|
||||
|
||||
def _growth_rate_ODE(cosmo, a):
|
||||
|
@ -313,10 +320,11 @@ def _growth_rate_ODE(cosmo, a):
|
|||
Growth rate computed at requested scale factor
|
||||
"""
|
||||
# Check if growth has already been computed, if not, compute it
|
||||
|
||||
|
||||
cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
|
||||
return interp(a, cache["a"], cache["f"])
|
||||
|
||||
|
||||
def _growth_factor_second_ODE(cosmo, a):
|
||||
"""Compute second order growth factor D2(a) at a given scale factor,
|
||||
normalised such that D(a=1) = 1.
|
||||
|
@ -384,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
|
|||
|
||||
"""
|
||||
# Check if growth has already been computed, if not, compute it
|
||||
if not "background.growth_factor" in cosmo._workspace.keys():
|
||||
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
|
||||
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
|
||||
):
|
||||
cache = cosmo._workspace["background.growth_factor"]
|
||||
else:
|
||||
# Compute tabulated array
|
||||
atab = np.logspace(log10_amin, 0.0, steps)
|
||||
|
||||
|
@ -395,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
|
|||
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
|
||||
gtab = gtab / gtab[-1] # Normalize to a=1.
|
||||
cache = {"a": atab, "g": gtab}
|
||||
cosmo._workspace["background.growth_factor"] = cache
|
||||
else:
|
||||
cache = cosmo._workspace["background.growth_factor"]
|
||||
if CACHING_ACTIVATED:
|
||||
cosmo._workspace["background.growth_factor"] = cache
|
||||
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
|
||||
|
||||
|
||||
|
@ -522,6 +533,7 @@ def gp(cosmo, a):
|
|||
D1f = f1 * g1 / a
|
||||
return D1f
|
||||
|
||||
|
||||
def dGfa(cosmo, a):
|
||||
r""" Derivative of Gf against a
|
||||
|
||||
|
@ -592,4 +604,4 @@ def dGf2a(cosmo, a):
|
|||
f2p = interp(np.log(a), np.log(cache['a']), f2p)
|
||||
E_a = E(cosmo, a)
|
||||
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
|
||||
3 * a**2 * E_a * D2f)
|
||||
3 * a**2 * E_a * D2f)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue