Add Jeffrey priors

This commit is contained in:
rstiskalek 2024-09-26 10:52:37 +01:00
parent 498cca52b7
commit 53b48363a1

View file

@ -31,9 +31,7 @@ from jax import jit
from jax import numpy as jnp
from jax.scipy.special import erf, logsumexp
from numpyro import factor, plate, sample
from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
Uniform, constraints)
from numpyro.distributions.util import promote_shapes
from numpyro.distributions import MultivariateNormal, Normal, Uniform
from quadax import simpson
from tqdm import trange
@ -152,35 +150,6 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
return normal_logpdf(x, loc, scale) - jnp.log(norm)
###############################################################################
# Inverse distribution #
###############################################################################
class InverseDistribution(Distribution):
"""Inverse distribution."""
support = constraints.positive
args_contraints = {
"low": constraints.positive,
"high": constraints.positive,
"fiducial_scale": constraints.positive,
}
reparametrized_params = ["low", "high"]
def __init__(self, low, high, fiducial_scale=1.):
self.low, self.high = promote_shapes(low, high)
self.log_fiducial_scale = jnp.log(fiducial_scale)
super(InverseDistribution, self).__init__(batch_shape=())
def sample(self, key, sample_shape=()):
z = Uniform(0, 1).sample(key, sample_shape)
return self.high * (self.low / self.high)**z
def log_prob(self, value):
return - (jnp.log(value) - self.log_fiducial_scale)
###############################################################################
# Base flow validation #
###############################################################################
@ -345,13 +314,11 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
+ e_mu_intrinsic**2)
def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
sample_alpha, name):
"""Sample SNIe Tripp parameters."""
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
alpha_cal = sample(
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
@ -383,18 +350,15 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
sample_a_dipole, name):
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
c_mean, c_std, alpha_min, alpha_max, sample_alpha,
a_dipole_mean, a_dipole_std, sample_a_dipole, name):
"""Sample Tully-Fisher calibration parameters."""
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
a = sample(f"a_{name}", Normal(a_mean, a_std))
if sample_a_dipole:
ax, ay, az = sample(f"a_dipole_{name}", Normal(
a_dipole_mean, a_dipole_std).expand([3]))
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa
else:
ax, ay, az = 0.0, 0.0, 0.0
@ -417,13 +381,11 @@ def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
# Simple calibration parameters sampling #
###############################################################################
def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
dmu_dipole_mean, dmu_dipole_std, sample_alpha,
sample_dmu_dipole, name):
"""Sample simple calibration parameters."""
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
@ -447,13 +409,11 @@ def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
sample_Vmono, sample_beta, sample_h, sample_rLG,
sample_Vmag_vax):
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
sample_h, sample_rLG, sample_Vmag_vax):
"""Sample the flow calibration."""
# The low/high are used only to generate an initial guess, they only
# have to be in the right order of magnitude so that's why hardcoded.
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
if sample_beta:
beta = sample("beta", Uniform(beta_min, beta_max))
@ -647,10 +607,15 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Vmono = field_calibration_params["Vmono"]
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
e_mu = distmod_params["e_mu"]
# Jeffrey's prior on sigma_v and the intrinsic scatter, they are above
# "sampled" from uniform distributions.
ll0 -= jnp.log(sigma_v) + jnp.log(e_mu)
# ------------------------------------------------------------
# 1. Sample true observables and obtain the distance estimate
# ------------------------------------------------------------
e_mu = distmod_params["e_mu"]
if self.kind == "SN":
mag_cal = distmod_params["mag_cal"]
alpha_cal = distmod_params["alpha_cal"]
@ -664,6 +629,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
c_mean, c_std = sample_gaussian_hyperprior(
"c", self.name, self.c_min, self.c_max)
# Jeffrey's prior on the the MNR hyperprior widths.
ll0 -= jnp.log(mag_std) + jnp.log(x1_std) + jnp.log(c_std)
# NOTE: that the true variables are currently uncorrelated.
with plate(f"true_SN_{self.name}", self.ndata):
mag_true = sample(
@ -715,6 +683,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
corr_mag_eta = sample(
f"corr_mag_eta_{self.name}", Uniform(-1, 1))
# Jeffrey's prior on the the MNR hyperprior widths.
ll0 -= jnp.log(mag_std) + jnp.log(eta_std)
loc = jnp.array([mag_mean, eta_mean])
cov = jnp.array(
[[mag_std**2, corr_mag_eta * mag_std * eta_std],