Add more about evidence and selection to flow (#142)

* Add Laplace evidence

* Numerically stable laplace evidence

* Minor edits to Laplace

* Remove rmax

* Rm old things

* Rm comments

* Add script

* Add super toy selection

* Add super toy selection

* Update script
This commit is contained in:
Richard Stiskalek 2024-08-27 00:36:00 +02:00 committed by GitHub
parent d13246a394
commit 3d1e1c0ae3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 243 additions and 57 deletions

View file

@ -10,7 +10,7 @@ MAS="SPH"
grid=1024
for simname in "Carrick2015"; do
for simname in "Lilow2024"; do
for catalogue in "CF4_TFR"; do
pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid"
if [ $on_login -eq 1 ]; then

View file

@ -72,7 +72,7 @@ def print_variables(names, variables):
print(flush=True)
def get_models(get_model_kwargs, verbose=True):
def get_models(get_model_kwargs, toy_selection, verbose=True):
"""Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/"
@ -110,7 +110,8 @@ def get_models(get_model_kwargs, verbose=True):
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
cat, fpath, paths,
ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
models[i] = csiborgtools.flow.get_model(
loader, toy_selection=toy_selection[i], **get_model_kwargs)
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
return models
@ -127,7 +128,7 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
calculate_harmonic, nchains_harmonic, epoch_num, kwargs_print):
"""Run the NumPyro model and save output to a file."""
try:
ndata = sum(model.ndata for model in model_kwargs["models"])
@ -148,12 +149,12 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
print(f"{'AIC':<20} {AIC}")
mcmc.print_summary()
if calculate_evidence:
if calculate_harmonic:
print("Calculating the evidence using `harmonic`.", flush=True)
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
samples, log_posterior, nchains_harmonic, epoch_num)
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
print(f"{'-ln(Z_h)':<20} {neg_ln_evidence}")
print(f"{'-ln(Z_h) error':<20} {neg_ln_evidence_err}")
else:
neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
@ -180,8 +181,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
grp = f.create_group("gof")
grp.create_dataset("BIC", data=BIC)
grp.create_dataset("AIC", data=AIC)
grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
grp.create_dataset("neg_lnZ_harmonic", data=neg_ln_evidence)
grp.create_dataset("neg_lnZ_harmonic_err", data=neg_ln_evidence_err)
fname_summary = fname.replace(".hdf5", ".txt")
print(f"Saving summary to `{fname_summary}`.")
@ -206,7 +207,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Command line interface #
###############################################################################
def get_distmod_hyperparams(catalogue, sample_alpha):
def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
alpha_min = -1.0
alpha_max = 3.0
@ -225,7 +226,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
"c_mean": 0., "c_std": 20.0,
"sample_curvature": False,
"a_dipole_mean": 0., "a_dipole_std": 1.0,
"sample_a_dipole": True,
"sample_a_dipole": sample_mag_dipole,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha,
}
@ -233,7 +234,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
"dmu_min": -3.0, "dmu_max": 3.0,
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
"sample_dmu_dipole": True,
"sample_dmu_dipole": sample_mag_dipole,
"alpha_min": alpha_min, "alpha_max": alpha_max,
"sample_alpha": sample_alpha,
}
@ -241,6 +242,16 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
def get_toy_selection(toy_selection, catalogue):
    """
    Return the hard-coded toy selection parameters for `catalogue`.

    Parameters
    ----------
    toy_selection : bool
        Whether the toy selection is switched on.
    catalogue : str
        Name of the catalogue.

    Returns
    -------
    list of 3 floats or None
        `None` if `toy_selection` is falsy, otherwise the three values
        registered for `catalogue`.
        NOTE(review): the meaning of the individual values is not visible
        here; confirm against the model that consumes them.

    Raises
    ------
    ValueError
        If the toy selection is requested for a catalogue that has no
        registered parameters.
    """
    if not toy_selection:
        return None

    if catalogue == "SFI_gals":
        return [1.221e+01, 1.297e+01, -2.708e-01]

    # Report the argument that was actually rejected, not the global list of
    # all requested catalogues.
    raise ValueError(f"Unsupported catalogue: `{catalogue}`.")
if __name__ == "__main__":
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity" # noqa
@ -251,18 +262,23 @@ if __name__ == "__main__":
# Fixed user parameters #
###########################################################################
nsteps = 1500
nburn = 1000
nsteps = 1000
nburn = 500
zcmb_min = 0
zcmb_max = 0.05
calculate_evidence = False
nchains_harmonic = 10
num_epochs = 30
inference_method = "mike"
num_epochs = 50
inference_method = "bayes"
calculate_harmonic = True if inference_method == "mike" else False
maxmag_selection = None
sample_alpha = True
sample_alpha = False
sample_beta = True
sample_Vmono = False
sample_mag_dipole = False
toy_selection = True
if toy_selection and inference_method == "mike":
raise ValueError("Toy selection is not supported with `mike` inference.") # noqa
if nsteps % nchains_harmonic != 0:
raise ValueError(
@ -272,10 +288,12 @@ if __name__ == "__main__":
"zcmb_min": zcmb_min,
"zcmb_max": zcmb_max,
"maxmag_selection": maxmag_selection,
"calculate_evidence": calculate_evidence,
"calculate_harmonic": calculate_harmonic,
"nchains_harmonic": nchains_harmonic,
"num_epochs": num_epochs,
"inference_method": inference_method}
"inference_method": inference_method,
"sample_mag_dipole": sample_mag_dipole,
"toy_selection": toy_selection}
print_variables(main_params.keys(), main_params.values())
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
@ -290,7 +308,7 @@ if __name__ == "__main__":
distmod_hyperparams_per_catalogue = []
for cat in ARGS.catalogue:
x = get_distmod_hyperparams(cat, sample_alpha)
x = get_distmod_hyperparams(cat, sample_alpha, sample_mag_dipole)
print(f"\n{cat} hyperparameters:")
print_variables(x.keys(), x.values())
distmod_hyperparams_per_catalogue.append(x)
@ -301,7 +319,11 @@ if __name__ == "__main__":
get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
"maxmag_selection": maxmag_selection}
models = get_models(get_model_kwargs, )
toy_selection = [get_toy_selection(toy_selection, cat)
for cat in ARGS.catalogue]
models = get_models(get_model_kwargs, toy_selection)
model_kwargs = {
"models": models,
"field_calibration_hyperparams": calibration_hyperparams,
@ -312,5 +334,5 @@ if __name__ == "__main__":
model = csiborgtools.flow.PV_validation_model
run_model(model, nsteps, nburn, model_kwargs, out_folder,
calibration_hyperparams["sample_beta"], calculate_evidence,
calibration_hyperparams["sample_beta"], calculate_harmonic,
nchains_harmonic, num_epochs, kwargs_print)

View file

@ -39,7 +39,7 @@ fi
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
for simname in "Carrick2015"; do
for catalogue in "CF4_GroupAll"; do
for catalogue in "SFI_gals"; do
# for catalogue in "CF4_TFR_i"; do
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
for ksim in "none"; do

View file

@ -0,0 +1,60 @@
from argparse import ArgumentParser, ArgumentTypeError
def parse_args():
    """Parse the command-line options of this script (only `--device`)."""
    cli = ArgumentParser()
    cli.add_argument(
        "--device",
        default="cpu",
        type=str,
        help="Device to use.",
    )
    parsed = cli.parse_args()
    return parsed
ARGS = parse_args()
# This must be done before we import JAX etc.
from numpyro import set_host_device_count, set_platform # noqa
set_platform(ARGS.device) # noqa
from jax import numpy as jnp # noqa
import numpy as np # noqa
import csiborgtools # noqa
from scipy.stats import multivariate_normal # noqa
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
    """
    Estimate the evidence with the `harmonic` package.

    The samples dictionary is converted to an array; both it and the
    log-posterior are reshaped so that their leading axis has length
    `nchains_harmonic` before being handed to
    `csiborgtools.harmonic_evidence`, whose output is returned unchanged.
    """
    sample_array, param_names = csiborgtools.dict_samples_to_array(samples)
    nparams = len(param_names)

    chains = sample_array.reshape(nchains_harmonic, -1, nparams)
    chain_log_posterior = log_posterior.reshape(nchains_harmonic, -1)

    evidence = csiborgtools.harmonic_evidence(
        chains, chain_log_posterior,
        return_flow_samples=False, epochs_num=epoch_num)
    return evidence
# Toy check: draw samples from a zero-mean, identity-covariance normal in
# `ndim` dimensions and compare the Laplace and `harmonic` evidence estimates.
ndim = 250
nsamples = 100_000
# Passed to both estimators as the number of chains to split the samples into.
nchains_split = 10
loc = jnp.zeros(ndim)
cov = jnp.eye(ndim)
# No seed is given, so the draws differ from run to run.
gen = np.random.default_rng()
X = gen.multivariate_normal(loc, cov, size=nsamples)
# One entry per dimension, keyed `x_0`, `x_1`, ...
samples = {f"x_{i}": X[:, i] for i in range(ndim)}
# Log-density of the same normal evaluated at every draw.
# NOTE(review): this density is normalised, so presumably both estimates below
# should come out close to zero; confirm against the estimators' conventions.
logprob = multivariate_normal(loc, cov).logpdf(X)
neg_lnZ_laplace, neg_lnZ_laplace_error = csiborgtools.laplace_evidence(
samples, logprob, nchains_split)
print(f"neg_lnZ_laplace: {neg_lnZ_laplace} +/- {neg_lnZ_laplace_error}")
neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
samples, logprob, nchains_split, epoch_num=30)
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")