mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-07-03 13:01:11 +00:00
Add Jeffrey priors
This commit is contained in:
parent
498cca52b7
commit
53b48363a1
1 changed files with 29 additions and 58 deletions
|
@ -31,9 +31,7 @@ from jax import jit
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax.scipy.special import erf, logsumexp
|
from jax.scipy.special import erf, logsumexp
|
||||||
from numpyro import factor, plate, sample
|
from numpyro import factor, plate, sample
|
||||||
from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
|
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
||||||
Uniform, constraints)
|
|
||||||
from numpyro.distributions.util import promote_shapes
|
|
||||||
from quadax import simpson
|
from quadax import simpson
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
@ -152,35 +150,6 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
|
||||||
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
# Inverse distribution #
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
|
|
||||||
class InverseDistribution(Distribution):
|
|
||||||
"""Inverse distribution."""
|
|
||||||
support = constraints.positive
|
|
||||||
args_contraints = {
|
|
||||||
"low": constraints.positive,
|
|
||||||
"high": constraints.positive,
|
|
||||||
"fiducial_scale": constraints.positive,
|
|
||||||
}
|
|
||||||
reparametrized_params = ["low", "high"]
|
|
||||||
|
|
||||||
def __init__(self, low, high, fiducial_scale=1.):
|
|
||||||
self.low, self.high = promote_shapes(low, high)
|
|
||||||
self.log_fiducial_scale = jnp.log(fiducial_scale)
|
|
||||||
|
|
||||||
super(InverseDistribution, self).__init__(batch_shape=())
|
|
||||||
|
|
||||||
def sample(self, key, sample_shape=()):
|
|
||||||
z = Uniform(0, 1).sample(key, sample_shape)
|
|
||||||
return self.high * (self.low / self.high)**z
|
|
||||||
|
|
||||||
def log_prob(self, value):
|
|
||||||
return - (jnp.log(value) - self.log_fiducial_scale)
|
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Base flow validation #
|
# Base flow validation #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -345,13 +314,11 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
|
||||||
+ e_mu_intrinsic**2)
|
+ e_mu_intrinsic**2)
|
||||||
|
|
||||||
|
|
||||||
def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
|
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
||||||
beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||||
sample_alpha, name):
|
sample_alpha, name):
|
||||||
"""Sample SNIe Tripp parameters."""
|
"""Sample SNIe Tripp parameters."""
|
||||||
# The low/high are used only to generate an initial guess, they only
|
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||||
# have to be in the right order of magnitude so that's why hardcoded.
|
|
||||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
|
||||||
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
||||||
alpha_cal = sample(
|
alpha_cal = sample(
|
||||||
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
||||||
|
@ -383,18 +350,15 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
|
||||||
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
|
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
|
||||||
|
|
||||||
|
|
||||||
def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
|
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
|
||||||
alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
|
c_mean, c_std, alpha_min, alpha_max, sample_alpha,
|
||||||
sample_a_dipole, name):
|
a_dipole_mean, a_dipole_std, sample_a_dipole, name):
|
||||||
"""Sample Tully-Fisher calibration parameters."""
|
"""Sample Tully-Fisher calibration parameters."""
|
||||||
# The low/high are used only to generate an initial guess, they only
|
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||||
# have to be in the right order of magnitude so that's why hardcoded.
|
|
||||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
|
||||||
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
||||||
|
|
||||||
if sample_a_dipole:
|
if sample_a_dipole:
|
||||||
ax, ay, az = sample(f"a_dipole_{name}", Normal(
|
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa
|
||||||
a_dipole_mean, a_dipole_std).expand([3]))
|
|
||||||
else:
|
else:
|
||||||
ax, ay, az = 0.0, 0.0, 0.0
|
ax, ay, az = 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
@ -417,13 +381,11 @@ def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
|
||||||
# Simple calibration parameters sampling #
|
# Simple calibration parameters sampling #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
|
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
|
||||||
dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
|
dmu_dipole_mean, dmu_dipole_std, sample_alpha,
|
||||||
|
sample_dmu_dipole, name):
|
||||||
"""Sample simple calibration parameters."""
|
"""Sample simple calibration parameters."""
|
||||||
# The low/high are used only to generate an initial guess, they only
|
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||||
# have to be in the right order of magnitude so that's why hardcoded.
|
|
||||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
|
||||||
|
|
||||||
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
|
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
|
||||||
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
|
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
|
||||||
|
|
||||||
|
@ -447,13 +409,11 @@ def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
|
||||||
|
|
||||||
|
|
||||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||||
beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
|
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
|
||||||
sample_Vmono, sample_beta, sample_h, sample_rLG,
|
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
|
||||||
sample_Vmag_vax):
|
sample_h, sample_rLG, sample_Vmag_vax):
|
||||||
"""Sample the flow calibration."""
|
"""Sample the flow calibration."""
|
||||||
# The low/high are used only to generate an initial guess, they only
|
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
||||||
# have to be in the right order of magnitude so that's why hardcoded.
|
|
||||||
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
|
|
||||||
|
|
||||||
if sample_beta:
|
if sample_beta:
|
||||||
beta = sample("beta", Uniform(beta_min, beta_max))
|
beta = sample("beta", Uniform(beta_min, beta_max))
|
||||||
|
@ -647,10 +607,15 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
Vmono = field_calibration_params["Vmono"]
|
Vmono = field_calibration_params["Vmono"]
|
||||||
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
||||||
|
|
||||||
|
e_mu = distmod_params["e_mu"]
|
||||||
|
|
||||||
|
# Jeffrey's prior on sigma_v and the intrinsic scatter, they are above
|
||||||
|
# "sampled" from uniform distributions.
|
||||||
|
ll0 -= jnp.log(sigma_v) + jnp.log(e_mu)
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
# 1. Sample true observables and obtain the distance estimate
|
# 1. Sample true observables and obtain the distance estimate
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
e_mu = distmod_params["e_mu"]
|
|
||||||
if self.kind == "SN":
|
if self.kind == "SN":
|
||||||
mag_cal = distmod_params["mag_cal"]
|
mag_cal = distmod_params["mag_cal"]
|
||||||
alpha_cal = distmod_params["alpha_cal"]
|
alpha_cal = distmod_params["alpha_cal"]
|
||||||
|
@ -664,6 +629,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
c_mean, c_std = sample_gaussian_hyperprior(
|
c_mean, c_std = sample_gaussian_hyperprior(
|
||||||
"c", self.name, self.c_min, self.c_max)
|
"c", self.name, self.c_min, self.c_max)
|
||||||
|
|
||||||
|
# Jeffrey's prior on the the MNR hyperprior widths.
|
||||||
|
ll0 -= jnp.log(mag_std) + jnp.log(x1_std) + jnp.log(c_std)
|
||||||
|
|
||||||
# NOTE: that the true variables are currently uncorrelated.
|
# NOTE: that the true variables are currently uncorrelated.
|
||||||
with plate(f"true_SN_{self.name}", self.ndata):
|
with plate(f"true_SN_{self.name}", self.ndata):
|
||||||
mag_true = sample(
|
mag_true = sample(
|
||||||
|
@ -715,6 +683,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
corr_mag_eta = sample(
|
corr_mag_eta = sample(
|
||||||
f"corr_mag_eta_{self.name}", Uniform(-1, 1))
|
f"corr_mag_eta_{self.name}", Uniform(-1, 1))
|
||||||
|
|
||||||
|
# Jeffrey's prior on the the MNR hyperprior widths.
|
||||||
|
ll0 -= jnp.log(mag_std) + jnp.log(eta_std)
|
||||||
|
|
||||||
loc = jnp.array([mag_mean, eta_mean])
|
loc = jnp.array([mag_mean, eta_mean])
|
||||||
cov = jnp.array(
|
cov = jnp.array(
|
||||||
[[mag_std**2, corr_mag_eta * mag_std * eta_std],
|
[[mag_std**2, corr_mag_eta * mag_std * eta_std],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue