diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index e2e0865..1139e27 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -31,7 +31,9 @@ from jax import jit from jax import numpy as jnp from jax.scipy.special import erf, logsumexp from numpyro import factor, plate, sample -from numpyro.distributions import MultivariateNormal, Normal, Uniform +from numpyro.distributions import (Distribution, MultivariateNormal, Normal, + Uniform, constraints) +from numpyro.distributions.util import promote_shapes from quadax import simpson from tqdm import trange @@ -150,6 +152,35 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax): return normal_logpdf(x, loc, scale) - jnp.log(norm) +############################################################################### +# Inverse distribution # +############################################################################### + + +class InverseDistribution(Distribution): + """Inverse distribution.""" + support = constraints.positive + args_contraints = { + "low": constraints.positive, + "high": constraints.positive, + "fiducial_scale": constraints.positive, + } + reparametrized_params = ["low", "high"] + + def __init__(self, low, high, fiducial_scale=1.): + self.low, self.high = promote_shapes(low, high) + self.log_fiducial_scale = jnp.log(fiducial_scale) + + super(InverseDistribution, self).__init__(batch_shape=()) + + def sample(self, key, sample_shape=()): + z = Uniform(0, 1).sample(key, sample_shape) + return self.high * (self.low / self.high)**z + + def log_prob(self, value): + return - (jnp.log(value) - self.log_fiducial_scale) + + ############################################################################### # Base flow validation # ############################################################################### @@ -314,11 +345,13 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic): + e_mu_intrinsic**2) -def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean, - alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max, +def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std, + beta_cal_mean, beta_cal_std, alpha_min, alpha_max, sample_alpha, name): """Sample SNIe Tripp parameters.""" - e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) + # The low/high are used only to generate an initial guess, they only + # have to be in the right order of magnitude so that's why hardcoded. + e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std)) alpha_cal = sample( f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std)) @@ -350,15 +383,18 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic): return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2 -def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, - c_mean, c_std, alpha_min, alpha_max, sample_alpha, - a_dipole_mean, a_dipole_std, sample_a_dipole, name): +def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min, + alpha_max, sample_alpha, a_dipole_mean, a_dipole_std, + sample_a_dipole, name): """Sample Tully-Fisher calibration parameters.""" - e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) + # The low/high are used only to generate an initial guess, they only + # have to be in the right order of magnitude so that's why hardcoded. + e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) a = sample(f"a_{name}", Normal(a_mean, a_std)) if sample_a_dipole: - ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa + ax, ay, az = sample(f"a_dipole_{name}", Normal( + a_dipole_mean, a_dipole_std).expand([3])) else: ax, ay, az = 0.0, 0.0, 0.0 @@ -381,11 +417,13 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std, # Simple calibration parameters sampling # ############################################################################### -def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, - dmu_dipole_mean, dmu_dipole_std, sample_alpha, - sample_dmu_dipole, name): +def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean, + dmu_dipole_std, sample_alpha, sample_dmu_dipole, name): """Sample simple calibration parameters.""" - e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max)) + # The low/high are used only to generate an initial guess, they only + # have to be in the right order of magnitude so that's why hardcoded. + e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15)) + dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max)) alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha) @@ -409,11 +447,13 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, - beta_max, sigma_v_min, sigma_v_max, h_min, h_max, - rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta, - sample_h, sample_rLG, sample_Vmag_vax): + beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext, + sample_Vmono, sample_beta, sample_h, sample_rLG, + sample_Vmag_vax): """Sample the flow calibration.""" - sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) + # The low/high are used only to generate an initial guess, they only + # have to be in the right order of magnitude so that's why hardcoded. + sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.)) if sample_beta: beta = sample("beta", Uniform(beta_min, beta_max)) diff --git a/scripts/flow/flow_validation.py b/scripts/flow/flow_validation.py index c560a19..454d88c 100644 --- a/scripts/flow/flow_validation.py +++ b/scripts/flow/flow_validation.py @@ -190,6 +190,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, neg_ln_evidence = jax.numpy.nan neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan) + # Temporarily disable saving. + return fname = join(out_folder, fname) print(f"Saving results to `{fname}`.") with File(fname, "w") as f: @@ -236,16 +238,14 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole): alpha_max = 10.0 if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa - return {"e_mu_min": 0.001, "e_mu_max": 1.0, - "mag_cal_mean": -18.25, "mag_cal_std": 2.0, + return {"mag_cal_mean": -18.25, "mag_cal_std": 2.0, "alpha_cal_mean": 0.148, "alpha_cal_std": 1.0, "beta_cal_mean": 3.112, "beta_cal_std": 2.0, "alpha_min": alpha_min, "alpha_max": alpha_max, "sample_alpha": sample_alpha } elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue: - return {"e_mu_min": 0.001, "e_mu_max": 1.0, - "a_mean": -21., "a_std": 5.0, + return {"a_mean": -21., "a_std": 5.0, "b_mean": -5.95, "b_std": 4.0, "c_mean": 0., "c_std": 20.0, "a_dipole_mean": 0., "a_dipole_std": 1.0, @@ -254,8 +254,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole): "sample_alpha": sample_alpha, } elif catalogue in ["CF4_GroupAll"]: - return {"e_mu_min": 0.001, "e_mu_max": 1.0, - "dmu_min": -3.0, "dmu_max": 3.0, + return {"dmu_min": -3.0, "dmu_max": 3.0, "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0, "sample_dmu_dipole": sample_mag_dipole, "alpha_min": alpha_min, "alpha_max": alpha_max, @@ -299,16 +298,16 @@ if __name__ == "__main__": ########################################################################### # `None` means default behaviour - nsteps = 10_000 + nsteps = 2_000 nburn = 2_000 zcmb_min = None zcmb_max = 0.05 nchains_harmonic = 10 num_epochs = 50 - inference_method = "mike" + inference_method = "bayes" mag_selection = None sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa - sample_beta = None + sample_beta = True no_Vext = None sample_Vmag_vax = False sample_Vmono = False @@ -377,7 +376,6 @@ if __name__ == "__main__": calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000, "Vmono_min": -1000, "Vmono_max": 1000, "beta_min": -10.0, "beta_max": 10.0, - "sigma_v_min": 1.0, "sigma_v_max": 5000 if "IndranilVoid_" in ARGS.simname else 750., # noqa "h_min": 0.01, "h_max": 5.0, "no_Vext": False if no_Vext is None else no_Vext, # noqa "sample_Vmag_vax": sample_Vmag_vax,