mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2025-04-19 08:20:53 +00:00
Add Jeffrey priors
This commit is contained in:
parent
498cca52b7
commit
53b48363a1
1 changed files with 29 additions and 58 deletions
|
@ -31,9 +31,7 @@ from jax import jit
|
|||
from jax import numpy as jnp
|
||||
from jax.scipy.special import erf, logsumexp
|
||||
from numpyro import factor, plate, sample
|
||||
from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
|
||||
Uniform, constraints)
|
||||
from numpyro.distributions.util import promote_shapes
|
||||
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
||||
from quadax import simpson
|
||||
from tqdm import trange
|
||||
|
||||
|
@ -152,35 +150,6 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
|
|||
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Inverse distribution #
|
||||
###############################################################################
|
||||
|
||||
|
||||
class InverseDistribution(Distribution):
|
||||
"""Inverse distribution."""
|
||||
support = constraints.positive
|
||||
args_contraints = {
|
||||
"low": constraints.positive,
|
||||
"high": constraints.positive,
|
||||
"fiducial_scale": constraints.positive,
|
||||
}
|
||||
reparametrized_params = ["low", "high"]
|
||||
|
||||
def __init__(self, low, high, fiducial_scale=1.):
|
||||
self.low, self.high = promote_shapes(low, high)
|
||||
self.log_fiducial_scale = jnp.log(fiducial_scale)
|
||||
|
||||
super(InverseDistribution, self).__init__(batch_shape=())
|
||||
|
||||
def sample(self, key, sample_shape=()):
|
||||
z = Uniform(0, 1).sample(key, sample_shape)
|
||||
return self.high * (self.low / self.high)**z
|
||||
|
||||
def log_prob(self, value):
|
||||
return - (jnp.log(value) - self.log_fiducial_scale)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Base flow validation #
|
||||
###############################################################################
|
||||
|
@ -345,13 +314,11 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
|
|||
+ e_mu_intrinsic**2)
|
||||
|
||||
|
||||
def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
|
||||
beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
||||
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||
sample_alpha, name):
|
||||
"""Sample SNIe Tripp parameters."""
|
||||
# The low/high are used only to generate an initial guess, they only
|
||||
# have to be in the right order of magnitude so that's why hardcoded.
|
||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
||||
alpha_cal = sample(
|
||||
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
||||
|
@ -383,18 +350,15 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
|
|||
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
|
||||
|
||||
|
||||
def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
|
||||
alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
|
||||
sample_a_dipole, name):
|
||||
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
|
||||
c_mean, c_std, alpha_min, alpha_max, sample_alpha,
|
||||
a_dipole_mean, a_dipole_std, sample_a_dipole, name):
|
||||
"""Sample Tully-Fisher calibration parameters."""
|
||||
# The low/high are used only to generate an initial guess, they only
|
||||
# have to be in the right order of magnitude so that's why hardcoded.
|
||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
||||
|
||||
if sample_a_dipole:
|
||||
ax, ay, az = sample(f"a_dipole_{name}", Normal(
|
||||
a_dipole_mean, a_dipole_std).expand([3]))
|
||||
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa
|
||||
else:
|
||||
ax, ay, az = 0.0, 0.0, 0.0
|
||||
|
||||
|
@ -417,13 +381,11 @@ def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
|
|||
# Simple calibration parameters sampling #
|
||||
###############################################################################
|
||||
|
||||
def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
|
||||
dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
|
||||
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
|
||||
dmu_dipole_mean, dmu_dipole_std, sample_alpha,
|
||||
sample_dmu_dipole, name):
|
||||
"""Sample simple calibration parameters."""
|
||||
# The low/high are used only to generate an initial guess, they only
|
||||
# have to be in the right order of magnitude so that's why hardcoded.
|
||||
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||
|
||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
||||
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
|
||||
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
|
||||
|
||||
|
@ -447,13 +409,11 @@ def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
|
|||
|
||||
|
||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||
beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
|
||||
sample_Vmono, sample_beta, sample_h, sample_rLG,
|
||||
sample_Vmag_vax):
|
||||
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
|
||||
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
|
||||
sample_h, sample_rLG, sample_Vmag_vax):
|
||||
"""Sample the flow calibration."""
|
||||
# The low/high are used only to generate an initial guess, they only
|
||||
# have to be in the right order of magnitude so that's why hardcoded.
|
||||
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
|
||||
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
||||
|
||||
if sample_beta:
|
||||
beta = sample("beta", Uniform(beta_min, beta_max))
|
||||
|
@ -647,10 +607,15 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
Vmono = field_calibration_params["Vmono"]
|
||||
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
||||
|
||||
e_mu = distmod_params["e_mu"]
|
||||
|
||||
# Jeffrey's prior on sigma_v and the intrinsic scatter, they are above
|
||||
# "sampled" from uniform distributions.
|
||||
ll0 -= jnp.log(sigma_v) + jnp.log(e_mu)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# 1. Sample true observables and obtain the distance estimate
|
||||
# ------------------------------------------------------------
|
||||
e_mu = distmod_params["e_mu"]
|
||||
if self.kind == "SN":
|
||||
mag_cal = distmod_params["mag_cal"]
|
||||
alpha_cal = distmod_params["alpha_cal"]
|
||||
|
@ -664,6 +629,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
c_mean, c_std = sample_gaussian_hyperprior(
|
||||
"c", self.name, self.c_min, self.c_max)
|
||||
|
||||
# Jeffrey's prior on the the MNR hyperprior widths.
|
||||
ll0 -= jnp.log(mag_std) + jnp.log(x1_std) + jnp.log(c_std)
|
||||
|
||||
# NOTE: that the true variables are currently uncorrelated.
|
||||
with plate(f"true_SN_{self.name}", self.ndata):
|
||||
mag_true = sample(
|
||||
|
@ -715,6 +683,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
corr_mag_eta = sample(
|
||||
f"corr_mag_eta_{self.name}", Uniform(-1, 1))
|
||||
|
||||
# Jeffrey's prior on the the MNR hyperprior widths.
|
||||
ll0 -= jnp.log(mag_std) + jnp.log(eta_std)
|
||||
|
||||
loc = jnp.array([mag_mean, eta_mean])
|
||||
cov = jnp.array(
|
||||
[[mag_std**2, corr_mag_eta * mag_std * eta_std],
|
||||
|
|
Loading…
Add table
Reference in a new issue