Allow env variable control of caching in growth

This commit is contained in:
Wassim Kabalan 2025-06-08 10:45:04 +02:00
parent e7112e0c25
commit 41ae41ace3

View file

@ -1,3 +1,5 @@
import os
import jax.numpy as np import jax.numpy as np
from jax.numpy import interp from jax.numpy import interp
from jax_cosmo.background import * from jax_cosmo.background import *
@ -243,7 +245,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
Growth factor computed at requested scale factor Growth factor computed at requested scale factor
""" """
# Check if growth has already been computed # Check if growth has already been computed
#if not "background.growth_factor" in cosmo._workspace.keys(): CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array # Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
@ -281,7 +287,6 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
cache = { cache = {
"a": atab, "a": atab,
"g": gtab, "g": gtab,
@ -291,6 +296,8 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
"f2": f2tab, "f2": f2tab,
"h2": h2tab, "h2": h2tab,
} }
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
@ -317,6 +324,7 @@ def _growth_rate_ODE(cosmo, a):
cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1] cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
return interp(a, cache["a"], cache["f"]) return interp(a, cache["a"], cache["f"])
def _growth_factor_second_ODE(cosmo, a): def _growth_factor_second_ODE(cosmo, a):
"""Compute second order growth factor D2(a) at a given scale factor, """Compute second order growth factor D2(a) at a given scale factor,
normalised such that D(a=1) = 1. normalised such that D(a=1) = 1.
@ -384,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
""" """
# Check if growth has already been computed, if not, compute it # Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys(): CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array # Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps) atab = np.logspace(log10_amin, 0.0, steps)
@ -395,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
gtab = gtab / gtab[-1] # Normalize to a=1. gtab = gtab / gtab[-1] # Normalize to a=1.
cache = {"a": atab, "g": gtab} cache = {"a": atab, "g": gtab}
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
@ -522,6 +533,7 @@ def gp(cosmo, a):
D1f = f1 * g1 / a D1f = f1 * g1 / a
return D1f return D1f
def dGfa(cosmo, a): def dGfa(cosmo, a):
r""" Derivative of Gf against a r""" Derivative of Gf against a