diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index 1139e27..4cfdcb1 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -31,9 +31,7 @@ from jax import jit from jax import numpy as jnp from jax.scipy.special import erf, logsumexp from numpyro import factor, plate, sample -from numpyro.distributions import (Distribution, MultivariateNormal, Normal, - Uniform, constraints) -from numpyro.distributions.util import promote_shapes +from numpyro.distributions import MultivariateNormal, Normal, Uniform from quadax import simpson from tqdm import trange @@ -152,35 +150,6 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax): return normal_logpdf(x, loc, scale) - jnp.log(norm) -############################################################################### -# Inverse distribution # -############################################################################### - - -class InverseDistribution(Distribution): - """Inverse distribution.""" - support = constraints.positive - args_contraints = { - "low": constraints.positive, - "high": constraints.positive, - "fiducial_scale": constraints.positive, - } - reparametrized_params = ["low", "high"] - - def __init__(self, low, high, fiducial_scale=1.): - self.low, self.high = promote_shapes(low, high) - self.log_fiducial_scale = jnp.log(fiducial_scale) - - super(InverseDistribution, self).__init__(batch_shape=()) - - def sample(self, key, sample_shape=()): - z = Uniform(0, 1).sample(key, sample_shape) - return self.high * (self.low / self.high)**z - - def log_prob(self, value): - return - (jnp.log(value) - self.log_fiducial_scale) - - ############################################################################### # Base flow validation # ############################################################################### @@ -345,13 +314,11 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic): + e_mu_intrinsic**2) -def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std, - beta_cal_mean, beta_cal_std, alpha_min, alpha_max, +def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean, + alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max, sample_alpha, name): """Sample SNIe Tripp parameters.""" - # The low/high are used only to generate an initial guess, they only - # have to be in the right order of magnitude so that's why hardcoded. - e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) + e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std)) alpha_cal = sample( f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std)) @@ -383,18 +350,15 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic): return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2 -def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min, - alpha_max, sample_alpha, a_dipole_mean, a_dipole_std, - sample_a_dipole, name): +def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, + c_mean, c_std, alpha_min, alpha_max, sample_alpha, + a_dipole_mean, a_dipole_std, sample_a_dipole, name): """Sample Tully-Fisher calibration parameters.""" - # The low/high are used only to generate an initial guess, they only - # have to be in the right order of magnitude so that's why hardcoded. - e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) + e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) a = sample(f"a_{name}", Normal(a_mean, a_std)) if sample_a_dipole: - ax, ay, az = sample(f"a_dipole_{name}", Normal( - a_dipole_mean, a_dipole_std).expand([3])) + ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa else: ax, ay, az = 0.0, 0.0, 0.0 @@ -417,13 +381,11 @@ def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min, # Simple calibration parameters sampling # ############################################################################### -def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean, - dmu_dipole_std, sample_alpha, sample_dmu_dipole, name): +def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, + dmu_dipole_mean, dmu_dipole_std, sample_alpha, + sample_dmu_dipole, name): """Sample simple calibration parameters.""" - # The low/high are used only to generate an initial guess, they only - # have to be in the right order of magnitude so that's why hardcoded. - e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) - + e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max)) alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) @@ -447,13 +409,11 @@ def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, - beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext, - sample_Vmono, sample_beta, sample_h, sample_rLG, - sample_Vmag_vax): + beta_max, sigma_v_min, sigma_v_max, h_min, h_max, + rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta, + sample_h, sample_rLG, sample_Vmag_vax): """Sample the flow calibration.""" - # The low/high are used only to generate an initial guess, they only - # have to be in the right order of magnitude so that's why hardcoded. - sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.)) + sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) if sample_beta: beta = sample("beta", Uniform(beta_min, beta_max)) @@ -647,10 +607,15 @@ class PV_LogLikelihood(BaseFlowValidationModel): Vmono = field_calibration_params["Vmono"] Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec) + e_mu = distmod_params["e_mu"] + + # Jeffrey's prior on sigma_v and the intrinsic scatter, they are above + # "sampled" from uniform distributions. + ll0 -= jnp.log(sigma_v) + jnp.log(e_mu) + # ------------------------------------------------------------ # 1. Sample true observables and obtain the distance estimate # ------------------------------------------------------------ - e_mu = distmod_params["e_mu"] if self.kind == "SN": mag_cal = distmod_params["mag_cal"] alpha_cal = distmod_params["alpha_cal"] @@ -664,6 +629,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): c_mean, c_std = sample_gaussian_hyperprior( "c", self.name, self.c_min, self.c_max) + # Jeffrey's prior on the the MNR hyperprior widths. + ll0 -= jnp.log(mag_std) + jnp.log(x1_std) + jnp.log(c_std) + # NOTE: that the true variables are currently uncorrelated. with plate(f"true_SN_{self.name}", self.ndata): mag_true = sample( @@ -715,6 +683,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): corr_mag_eta = sample( f"corr_mag_eta_{self.name}", Uniform(-1, 1)) + # Jeffrey's prior on the the MNR hyperprior widths. + ll0 -= jnp.log(mag_std) + jnp.log(eta_std) + loc = jnp.array([mag_mean, eta_mean]) cov = jnp.array( [[mag_std**2, corr_mag_eta * mag_std * eta_std],