Switch to Jeffrey

This commit is contained in:
rstiskalek 2024-09-25 16:47:27 +01:00
parent f7d16a9cd0
commit 498cca52b7
2 changed files with 65 additions and 27 deletions

View file

@ -31,7 +31,9 @@ from jax import jit
from jax import numpy as jnp
from jax.scipy.special import erf, logsumexp
from numpyro import factor, plate, sample
from numpyro.distributions import MultivariateNormal, Normal, Uniform
from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
Uniform, constraints)
from numpyro.distributions.util import promote_shapes
from quadax import simpson
from tqdm import trange
@ -150,6 +152,35 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
return normal_logpdf(x, loc, scale) - jnp.log(norm)
###############################################################################
# Inverse distribution #
###############################################################################
class InverseDistribution(Distribution):
"""Inverse distribution."""
support = constraints.positive
args_contraints = {
"low": constraints.positive,
"high": constraints.positive,
"fiducial_scale": constraints.positive,
}
reparametrized_params = ["low", "high"]
def __init__(self, low, high, fiducial_scale=1.):
self.low, self.high = promote_shapes(low, high)
self.log_fiducial_scale = jnp.log(fiducial_scale)
super(InverseDistribution, self).__init__(batch_shape=())
def sample(self, key, sample_shape=()):
z = Uniform(0, 1).sample(key, sample_shape)
return self.high * (self.low / self.high)**z
def log_prob(self, value):
return - (jnp.log(value) - self.log_fiducial_scale)
###############################################################################
# Base flow validation #
###############################################################################
@ -314,11 +345,13 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
+ e_mu_intrinsic**2)
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
sample_alpha, name):
"""Sample SNIe Tripp parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample(
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
@ -350,15 +383,18 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
c_mean, c_std, alpha_min, alpha_max, sample_alpha,
a_dipole_mean, a_dipole_std, sample_a_dipole, name):
def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
sample_a_dipole, name):
"""Sample Tully-Fisher calibration parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
a = sample(f"a_{name}", Normal(a_mean, a_std))
if sample_a_dipole:
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa
ax, ay, az = sample(f"a_dipole_{name}", Normal(
a_dipole_mean, a_dipole_std).expand([3]))
else:
ax, ay, az = 0.0, 0.0, 0.0
@ -381,11 +417,13 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
# Simple calibration parameters sampling #
###############################################################################
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
dmu_dipole_mean, dmu_dipole_std, sample_alpha,
sample_dmu_dipole, name):
def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
"""Sample simple calibration parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
@ -409,11 +447,13 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
sample_h, sample_rLG, sample_Vmag_vax):
beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
sample_Vmono, sample_beta, sample_h, sample_rLG,
sample_Vmag_vax):
"""Sample the flow calibration."""
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
if sample_beta:
beta = sample("beta", Uniform(beta_min, beta_max))

View file

@ -190,6 +190,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder,
neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
# Temporarily disable saving.
return
fname = join(out_folder, fname)
print(f"Saving results to `{fname}`.")
with File(fname, "w") as f:
@ -236,16 +238,14 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
alpha_max = 10.0
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
return {"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha
}
elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"a_mean": -21., "a_std": 5.0,
return {"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 4.0,
"c_mean": 0., "c_std": 20.0,
"a_dipole_mean": 0., "a_dipole_std": 1.0,
@ -254,8 +254,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
"sample_alpha": sample_alpha,
}
elif catalogue in ["CF4_GroupAll"]:
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"dmu_min": -3.0, "dmu_max": 3.0,
return {"dmu_min": -3.0, "dmu_max": 3.0,
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
"sample_dmu_dipole": sample_mag_dipole,
"alpha_min": alpha_min, "alpha_max": alpha_max,
@ -299,16 +298,16 @@ if __name__ == "__main__":
###########################################################################
# `None` means default behaviour
nsteps = 10_000
nsteps = 2_000
nburn = 2_000
zcmb_min = None
zcmb_max = 0.05
nchains_harmonic = 10
num_epochs = 50
inference_method = "mike"
inference_method = "bayes"
mag_selection = None
sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa
sample_beta = None
sample_beta = True
no_Vext = None
sample_Vmag_vax = False
sample_Vmono = False
@ -377,7 +376,6 @@ if __name__ == "__main__":
calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000,
"Vmono_min": -1000, "Vmono_max": 1000,
"beta_min": -10.0, "beta_max": 10.0,
"sigma_v_min": 1.0, "sigma_v_max": 5000 if "IndranilVoid_" in ARGS.simname else 750., # noqa
"h_min": 0.01, "h_max": 5.0,
"no_Vext": False if no_Vext is None else no_Vext, # noqa
"sample_Vmag_vax": sample_Vmag_vax,