mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2025-04-19 06:40:54 +00:00
Switch to Jeffrey
This commit is contained in:
parent
f7d16a9cd0
commit
498cca52b7
2 changed files with 65 additions and 27 deletions
|
@ -31,7 +31,9 @@ from jax import jit
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
from jax.scipy.special import erf, logsumexp
|
from jax.scipy.special import erf, logsumexp
|
||||||
from numpyro import factor, plate, sample
|
from numpyro import factor, plate, sample
|
||||||
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
from numpyro.distributions import (Distribution, MultivariateNormal, Normal,
|
||||||
|
Uniform, constraints)
|
||||||
|
from numpyro.distributions.util import promote_shapes
|
||||||
from quadax import simpson
|
from quadax import simpson
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
@ -150,6 +152,35 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
|
||||||
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Inverse distribution #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
class InverseDistribution(Distribution):
|
||||||
|
"""Inverse distribution."""
|
||||||
|
support = constraints.positive
|
||||||
|
args_contraints = {
|
||||||
|
"low": constraints.positive,
|
||||||
|
"high": constraints.positive,
|
||||||
|
"fiducial_scale": constraints.positive,
|
||||||
|
}
|
||||||
|
reparametrized_params = ["low", "high"]
|
||||||
|
|
||||||
|
def __init__(self, low, high, fiducial_scale=1.):
|
||||||
|
self.low, self.high = promote_shapes(low, high)
|
||||||
|
self.log_fiducial_scale = jnp.log(fiducial_scale)
|
||||||
|
|
||||||
|
super(InverseDistribution, self).__init__(batch_shape=())
|
||||||
|
|
||||||
|
def sample(self, key, sample_shape=()):
|
||||||
|
z = Uniform(0, 1).sample(key, sample_shape)
|
||||||
|
return self.high * (self.low / self.high)**z
|
||||||
|
|
||||||
|
def log_prob(self, value):
|
||||||
|
return - (jnp.log(value) - self.log_fiducial_scale)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Base flow validation #
|
# Base flow validation #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -314,11 +345,13 @@ def e2_distmod_SN(e2_mag, e2_x1, e2_c, alpha_cal, beta_cal, e_mu_intrinsic):
|
||||||
+ e_mu_intrinsic**2)
|
+ e_mu_intrinsic**2)
|
||||||
|
|
||||||
|
|
||||||
def sample_SN(e_mu_min, e_mu_max, mag_cal_mean, mag_cal_std, alpha_cal_mean,
|
def sample_SN(mag_cal_mean, mag_cal_std, alpha_cal_mean, alpha_cal_std,
|
||||||
alpha_cal_std, beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
beta_cal_mean, beta_cal_std, alpha_min, alpha_max,
|
||||||
sample_alpha, name):
|
sample_alpha, name):
|
||||||
"""Sample SNIe Tripp parameters."""
|
"""Sample SNIe Tripp parameters."""
|
||||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
# The low/high are used only to generate an initial guess, they only
|
||||||
|
# have to be in the right order of magnitude so that's why hardcoded.
|
||||||
|
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||||
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
mag_cal = sample(f"mag_cal_{name}", Normal(mag_cal_mean, mag_cal_std))
|
||||||
alpha_cal = sample(
|
alpha_cal = sample(
|
||||||
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
f"alpha_cal_{name}", Normal(alpha_cal_mean, alpha_cal_std))
|
||||||
|
@ -350,15 +383,18 @@ def e2_distmod_TFR(e2_mag, e2_eta, eta, b, c, e_mu_intrinsic):
|
||||||
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
|
return e2_mag + (b + 2 * c * eta)**2 * e2_eta + e_mu_intrinsic**2
|
||||||
|
|
||||||
|
|
||||||
def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
|
def sample_TFR(a_mean, a_std, b_mean, b_std, c_mean, c_std, alpha_min,
|
||||||
c_mean, c_std, alpha_min, alpha_max, sample_alpha,
|
alpha_max, sample_alpha, a_dipole_mean, a_dipole_std,
|
||||||
a_dipole_mean, a_dipole_std, sample_a_dipole, name):
|
sample_a_dipole, name):
|
||||||
"""Sample Tully-Fisher calibration parameters."""
|
"""Sample Tully-Fisher calibration parameters."""
|
||||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
# The low/high are used only to generate an initial guess, they only
|
||||||
|
# have to be in the right order of magnitude so that's why hardcoded.
|
||||||
|
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||||
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
a = sample(f"a_{name}", Normal(a_mean, a_std))
|
||||||
|
|
||||||
if sample_a_dipole:
|
if sample_a_dipole:
|
||||||
ax, ay, az = sample(f"a_dipole_{name}", Normal(a_dipole_mean, a_dipole_std).expand([3])) # noqa
|
ax, ay, az = sample(f"a_dipole_{name}", Normal(
|
||||||
|
a_dipole_mean, a_dipole_std).expand([3]))
|
||||||
else:
|
else:
|
||||||
ax, ay, az = 0.0, 0.0, 0.0
|
ax, ay, az = 0.0, 0.0, 0.0
|
||||||
|
|
||||||
|
@ -381,11 +417,13 @@ def sample_TFR(e_mu_min, e_mu_max, a_mean, a_std, b_mean, b_std,
|
||||||
# Simple calibration parameters sampling #
|
# Simple calibration parameters sampling #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
|
def sample_simple(dmu_min, dmu_max, alpha_min, alpha_max, dmu_dipole_mean,
|
||||||
dmu_dipole_mean, dmu_dipole_std, sample_alpha,
|
dmu_dipole_std, sample_alpha, sample_dmu_dipole, name):
|
||||||
sample_dmu_dipole, name):
|
|
||||||
"""Sample simple calibration parameters."""
|
"""Sample simple calibration parameters."""
|
||||||
e_mu = sample(f"e_mu_{name}", Uniform(e_mu_min, e_mu_max))
|
# The low/high are used only to generate an initial guess, they only
|
||||||
|
# have to be in the right order of magnitude so that's why hardcoded.
|
||||||
|
e_mu = sample(f"e_mu_{name}", InverseDistribution(0.2, 0.5, 0.15))
|
||||||
|
|
||||||
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
|
dmu = sample(f"dmu_{name}", Uniform(dmu_min, dmu_max))
|
||||||
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
|
alpha = sample_alpha_bias(name, alpha_min, alpha_max, sample_alpha)
|
||||||
|
|
||||||
|
@ -409,11 +447,13 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
|
||||||
|
|
||||||
|
|
||||||
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||||
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
|
beta_max, h_min, h_max, rLG_min, rLG_max, no_Vext,
|
||||||
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
|
sample_Vmono, sample_beta, sample_h, sample_rLG,
|
||||||
sample_h, sample_rLG, sample_Vmag_vax):
|
sample_Vmag_vax):
|
||||||
"""Sample the flow calibration."""
|
"""Sample the flow calibration."""
|
||||||
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
|
# The low/high are used only to generate an initial guess, they only
|
||||||
|
# have to be in the right order of magnitude so that's why hardcoded.
|
||||||
|
sigma_v = sample("sigma_v", InverseDistribution(100, 1000., 150.))
|
||||||
|
|
||||||
if sample_beta:
|
if sample_beta:
|
||||||
beta = sample("beta", Uniform(beta_min, beta_max))
|
beta = sample("beta", Uniform(beta_min, beta_max))
|
||||||
|
|
|
@ -190,6 +190,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
||||||
neg_ln_evidence = jax.numpy.nan
|
neg_ln_evidence = jax.numpy.nan
|
||||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||||
|
|
||||||
|
# Temporarily disable saving.
|
||||||
|
return
|
||||||
fname = join(out_folder, fname)
|
fname = join(out_folder, fname)
|
||||||
print(f"Saving results to `{fname}`.")
|
print(f"Saving results to `{fname}`.")
|
||||||
with File(fname, "w") as f:
|
with File(fname, "w") as f:
|
||||||
|
@ -236,16 +238,14 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
|
||||||
alpha_max = 10.0
|
alpha_max = 10.0
|
||||||
|
|
||||||
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
||||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
return {"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
||||||
"mag_cal_mean": -18.25, "mag_cal_std": 2.0,
|
|
||||||
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
"alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
|
||||||
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
"beta_cal_mean": 3.112, "beta_cal_std": 2.0,
|
||||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
"sample_alpha": sample_alpha
|
"sample_alpha": sample_alpha
|
||||||
}
|
}
|
||||||
elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
|
elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
|
||||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
return {"a_mean": -21., "a_std": 5.0,
|
||||||
"a_mean": -21., "a_std": 5.0,
|
|
||||||
"b_mean": -5.95, "b_std": 4.0,
|
"b_mean": -5.95, "b_std": 4.0,
|
||||||
"c_mean": 0., "c_std": 20.0,
|
"c_mean": 0., "c_std": 20.0,
|
||||||
"a_dipole_mean": 0., "a_dipole_std": 1.0,
|
"a_dipole_mean": 0., "a_dipole_std": 1.0,
|
||||||
|
@ -254,8 +254,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
|
||||||
"sample_alpha": sample_alpha,
|
"sample_alpha": sample_alpha,
|
||||||
}
|
}
|
||||||
elif catalogue in ["CF4_GroupAll"]:
|
elif catalogue in ["CF4_GroupAll"]:
|
||||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
return {"dmu_min": -3.0, "dmu_max": 3.0,
|
||||||
"dmu_min": -3.0, "dmu_max": 3.0,
|
|
||||||
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
|
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
|
||||||
"sample_dmu_dipole": sample_mag_dipole,
|
"sample_dmu_dipole": sample_mag_dipole,
|
||||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
|
@ -299,16 +298,16 @@ if __name__ == "__main__":
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
||||||
# `None` means default behaviour
|
# `None` means default behaviour
|
||||||
nsteps = 10_000
|
nsteps = 2_000
|
||||||
nburn = 2_000
|
nburn = 2_000
|
||||||
zcmb_min = None
|
zcmb_min = None
|
||||||
zcmb_max = 0.05
|
zcmb_max = 0.05
|
||||||
nchains_harmonic = 10
|
nchains_harmonic = 10
|
||||||
num_epochs = 50
|
num_epochs = 50
|
||||||
inference_method = "mike"
|
inference_method = "bayes"
|
||||||
mag_selection = None
|
mag_selection = None
|
||||||
sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa
|
sample_alpha = False if "IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field" else True # noqa
|
||||||
sample_beta = None
|
sample_beta = True
|
||||||
no_Vext = None
|
no_Vext = None
|
||||||
sample_Vmag_vax = False
|
sample_Vmag_vax = False
|
||||||
sample_Vmono = False
|
sample_Vmono = False
|
||||||
|
@ -377,7 +376,6 @@ if __name__ == "__main__":
|
||||||
calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000,
|
calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000,
|
||||||
"Vmono_min": -1000, "Vmono_max": 1000,
|
"Vmono_min": -1000, "Vmono_max": 1000,
|
||||||
"beta_min": -10.0, "beta_max": 10.0,
|
"beta_min": -10.0, "beta_max": 10.0,
|
||||||
"sigma_v_min": 1.0, "sigma_v_max": 5000 if "IndranilVoid_" in ARGS.simname else 750., # noqa
|
|
||||||
"h_min": 0.01, "h_max": 5.0,
|
"h_min": 0.01, "h_max": 5.0,
|
||||||
"no_Vext": False if no_Vext is None else no_Vext, # noqa
|
"no_Vext": False if no_Vext is None else no_Vext, # noqa
|
||||||
"sample_Vmag_vax": sample_Vmag_vax,
|
"sample_Vmag_vax": sample_Vmag_vax,
|
||||||
|
|
Loading…
Add table
Reference in a new issue