Plots of VF (#134)

* Add VF plots

* Update nb

* Add CMB velocity note

* rm nb

* Add option to return alllikelihood

* Add simulation weights

* Update nb

* Add bulkflow

* Update nb

* Add values of beta

* Update imports

* Update imports

* Add paths to Carrick and Lilow fiels

* Add Carrick and Lilow fields

* Add support for more fields

* Update bulkflow comp

* Update nb

* Update script
This commit is contained in:
Richard Stiskalek 2024-07-01 11:48:50 +01:00 committed by GitHub
parent 7dad6885e8
commit c6f49790bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1208 additions and 2680 deletions

View file

@ -48,18 +48,18 @@ def parse_args():
ARGS = parse_args()
# This must be done before we import JAX etc.
from numpyro import set_host_device_count, set_platform # noqa
set_platform(ARGS.device) # noqa
set_host_device_count(ARGS.ndevice) # noqa
import sys # noqa
from os.path import join # noqa
import csiborgtools # noqa
import jax # noqa
from h5py import File # noqa
from mpi4py import MPI # noqa
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
import csiborgtools # noqa
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
def print_variables(names, variables):
@ -113,6 +113,19 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
def get_simulation_weights(samples, model, model_kwargs):
"""Get the weights per posterior samples for each simulation."""
predictive = Predictive(model, samples)
ll_all = predictive(
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
# Multiply the likelihood of galaxies
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
# Normalization by summing the likelihood over simulations
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
return ll_per_simulation - norm[:, None]
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
"""Run the NumPyro model and save output to a file."""
@ -128,6 +141,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
samples = mcmc.get_samples()
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
log_likelihood = samples.pop("ll_values")
@ -141,13 +155,13 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
if calculate_evidence:
print("Calculating the evidence using `harmonic`.", flush=True)
ln_evidence, ln_evidence_err = get_harmonic_evidence(
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
samples, log_posterior, nchains_harmonic, epoch_num)
print(f"{'ln(Z)':<20} {ln_evidence}")
print(f"{'ln(Z) error':<20} {ln_evidence_err}")
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
else:
ln_evidence = jax.numpy.nan
ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
neg_ln_evidence = jax.numpy.nan
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
if ARGS.ksim is not None:
@ -167,13 +181,14 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
# Write log likelihood and posterior
f.create_dataset("log_likelihood", data=log_likelihood)
f.create_dataset("log_posterior", data=log_posterior)
f.create_dataset("simulation_weights", data=simulation_weights)
# Write goodness of fit
grp = f.create_group("gof")
grp.create_dataset("BIC", data=BIC)
grp.create_dataset("AIC", data=AIC)
grp.create_dataset("lnZ", data=ln_evidence)
grp.create_dataset("lnZ_err", data=ln_evidence_err)
grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
fname_summary = fname.replace(".hdf5", ".txt")
print(f"Saving summary to `{fname_summary}`.")
@ -188,8 +203,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
print("HMC summary:")
print(f"{'BIC':<20} {BIC}")
print(f"{'AIC':<20} {AIC}")
print(f"{'ln(Z)':<20} {ln_evidence}")
print(f"{'ln(Z) error':<20} {ln_evidence_err}")
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
mcmc.print_summary(exclude_deterministic=False)
sys.stdout = original_stdout
@ -238,7 +253,7 @@ if __name__ == "__main__":
print_variables(
calibration_hyperparams.keys(), calibration_hyperparams.values())
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups"]: # noqa
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
distmod_hyperparams = {"e_mu_mean": 0.1, "e_mu_std": 0.05,
"mag_cal_mean": -18.25, "mag_cal_std": 0.5,
"alpha_cal_mean": 0.148, "alpha_cal_std": 0.05,