Switch to Jeffrey

This commit is contained in:
rstiskalek 2024-09-25 16:47:27 +01:00
parent f7d16a9cd0
commit 498cca52b7
2 changed files with 65 additions and 27 deletions

View file

@ -31,7 +31,9 @@ from jax import jit
from jax import numpy as jnp from jax import numpy as jnp
from jax.scipy.special import erf, logsumexp from jax.scipy.special import erf, logsumexp
from numpyro import factor, plate, sample from numpyro import factor, plate, sample
from numpyro.distributions import MultivariateNormal, Normal, Uniform from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
Uniform, constraints)
from numpyro.distributions.util import promote_shapes
from quadax import simpson from quadax import simpson
from tqdm import trange from tqdm import trange
@ -150,6 +152,35 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
return normal_logpdf(x, loc, scale) - jnp.log(norm) return normal_logpdf(x, loc, scale) - jnp.log(norm)
###############################################################################
# Inverse distribution #
###############################################################################
class InverseDistribution(Distribution):
"""Inverse distribution."""
support = constraints.positive
args_contraints = {
"low": constraints.positive,
"high": constraints.positive,
"fiducial_scale": constraints.positive,
}
reparametrized_params = ["low", "high"]
def __init__(self, low, high, fiducial_scale=1.):
self.low, self.high = promote_shapes(low, high)
self.log_fiducial_scale = jnp.log(fiducial_scale)
super(InverseDistribution, self).__init__(batch_shape=())
def sample(self, key, sample_shape=()):
z = Uniform(0, 1).sample(key, sample_shape)
return self.high * (self.low / self.high)**z
def log_prob(self, value):
return - (jnp.log(value) - self.log_fiducial_scale)
############################################################################### ###############################################################################
# Base flow validation # # Base flow validation #
############################################################################### ###############################################################################
@ -314,11 +345,13 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
+ e_mu_intrinsic**2) + e_mu_intrinsic**2)
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean, def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
sample_alpha, name): sample_alpha, name):
"""Sample SNIe Tripp parameters.""" """Sample SNIe Tripp parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) # The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std)) mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample( alpha_cal = sample(
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std)) f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
@ -350,15 +383,18 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2 return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
c_mean, c_std, alpha_min, alpha_max, sample_alpha, alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
a_dipole_mean, a_dipole_std, sample_a_dipole, name): sample_a_dipole, name):
"""Sample Tully-Fisher calibration parameters.""" """Sample Tully-Fisher calibration parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) # The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
a = sample(f"a_{name}", Normal(a_mean, a_std)) a = sample(f"a_{name}", Normal(a_mean, a_std))
if sample_a_dipole: if sample_a_dipole:
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa ax, ay, az = sample(f"a_dipole_{name}", Normal(
a_dipole_mean, a_dipole_std).expand([3]))
else: else:
ax, ay, az = 0.0, 0.0, 0.0 ax, ay, az = 0.0, 0.0, 0.0
@ -381,11 +417,13 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
# Simple calibration parameters sampling # # Simple calibration parameters sampling #
############################################################################### ###############################################################################
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
dmu_dipole_mean, dmu_dipole_std, sample_alpha, dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
sample_dmu_dipole, name):
"""Sample simple calibration parameters.""" """Sample simple calibration parameters."""
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) # The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max)) dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
@ -409,11 +447,13 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
beta_max, sigma_v_min, sigma_v_max, h_min, h_max, beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta, sample_Vmono, sample_beta, sample_h, sample_rLG,
sample_h, sample_rLG, sample_Vmag_vax): sample_Vmag_vax):
"""Sample the flow calibration.""" """Sample the flow calibration."""
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) # The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
if sample_beta: if sample_beta:
beta = sample("beta", Uniform(beta_min, beta_max)) beta = sample("beta", Uniform(beta_min, beta_max))

View file

@ -190,6 +190,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder,
neg_ln_evidence = jax.numpy.nan neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan) neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
# Temporarily disable saving.
return
fname = join(out_folder, fname) fname = join(out_folder, fname)
print(f"Saving results to `{fname}`.") print(f"Saving results to `{fname}`.")
with File(fname, "w") as f: with File(fname, "w") as f:
@ -236,16 +238,14 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
alpha_max = 10.0 alpha_max = 10.0
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
return {"e_mu_min": 0.001, "e_mu_max": 1.0, return {"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0, "alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
"beta_cal_mean": 3.112, "beta_cal_std": 2.0, "beta_cal_mean": 3.112, "beta_cal_std": 2.0,
"alpha_min": alpha_min, "alpha_max": alpha_max, "alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha "sample_alpha": sample_alpha
} }
elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue: elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
return {"e_mu_min": 0.001, "e_mu_max": 1.0, return {"a_mean": -21., "a_std": 5.0,
"a_mean": -21., "a_std": 5.0,
"b_mean": -5.95, "b_std": 4.0, "b_mean": -5.95, "b_std": 4.0,
"c_mean": 0., "c_std": 20.0, "c_mean": 0., "c_std": 20.0,
"a_dipole_mean": 0., "a_dipole_std": 1.0, "a_dipole_mean": 0., "a_dipole_std": 1.0,
@ -254,8 +254,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
"sample_alpha": sample_alpha, "sample_alpha": sample_alpha,
} }
elif catalogue in ["CF4_GroupAll"]: elif catalogue in ["CF4_GroupAll"]:
return {"e_mu_min": 0.001, "e_mu_max": 1.0, return {"dmu_min": -3.0, "dmu_max": 3.0,
"dmu_min": -3.0, "dmu_max": 3.0,
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0, "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
"sample_dmu_dipole": sample_mag_dipole, "sample_dmu_dipole": sample_mag_dipole,
"alpha_min": alpha_min, "alpha_max": alpha_max, "alpha_min": alpha_min, "alpha_max": alpha_max,
@ -299,16 +298,16 @@ if __name__ == "__main__":
########################################################################### ###########################################################################
# `None` means default behaviour # `None` means default behaviour
nsteps = 10_000 nsteps = 2_000
nburn = 2_000 nburn = 2_000
zcmb_min = None zcmb_min = None
zcmb_max = 0.05 zcmb_max = 0.05
nchains_harmonic = 10 nchains_harmonic = 10
num_epochs = 50 num_epochs = 50
inference_method = "mike" inference_method = "bayes"
mag_selection = None mag_selection = None
sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa
sample_beta = None sample_beta = True
no_Vext = None no_Vext = None
sample_Vmag_vax = False sample_Vmag_vax = False
sample_Vmono = False sample_Vmono = False
@ -377,7 +376,6 @@ if __name__ == "__main__":
calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000, calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000,
"Vmono_min": -1000, "Vmono_max": 1000, "Vmono_min": -1000, "Vmono_max": 1000,
"beta_min": -10.0, "beta_max": 10.0, "beta_min": -10.0, "beta_max": 10.0,
"sigma_v_min": 1.0, "sigma_v_max": 5000 if "IndranilVoid_" in ARGS.simname else 750., # noqa
"h_min": 0.01, "h_max": 5.0, "h_min": 0.01, "h_max": 5.0,
"no_Vext": False if no_Vext is None else no_Vext, # noqa "no_Vext": False if no_Vext is None else no_Vext, # noqa
"sample_Vmag_vax": sample_Vmag_vax, "sample_Vmag_vax": sample_Vmag_vax,