mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-22 02:21:11 +00:00
Plots of VF (#134)
* Add VF plots * Update nb * Add CMB velocity note * rm nb * Add option to return alllikelihood * Add simulation weights * Update nb * Add bulkflow * Update nb * Add values of beta * Update imports * Update imports * Add paths to Carrick and Lilow fiels * Add Carrick and Lilow fields * Add support for more fields * Update bulkflow comp * Update nb * Update script
This commit is contained in:
parent
7dad6885e8
commit
c6f49790bf
13 changed files with 1208 additions and 2680 deletions
|
@ -48,18 +48,18 @@ def parse_args():
|
|||
ARGS = parse_args()
|
||||
# This must be done before we import JAX etc.
|
||||
from numpyro import set_host_device_count, set_platform # noqa
|
||||
|
||||
set_platform(ARGS.device) # noqa
|
||||
set_host_device_count(ARGS.ndevice) # noqa
|
||||
|
||||
import sys # noqa
|
||||
from os.path import join # noqa
|
||||
|
||||
import csiborgtools # noqa
|
||||
import jax # noqa
|
||||
from h5py import File # noqa
|
||||
from mpi4py import MPI # noqa
|
||||
from numpyro.infer import MCMC, NUTS, init_to_median # noqa
|
||||
|
||||
import csiborgtools # noqa
|
||||
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median # noqa
|
||||
|
||||
|
||||
def print_variables(names, variables):
|
||||
|
@ -113,6 +113,19 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
|||
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
||||
|
||||
|
||||
def get_simulation_weights(samples, model, model_kwargs):
|
||||
"""Get the weights per posterior samples for each simulation."""
|
||||
predictive = Predictive(model, samples)
|
||||
ll_all = predictive(
|
||||
jax.random.PRNGKey(1), store_ll_all=True, **model_kwargs)["ll_all"]
|
||||
|
||||
# Multiply the likelihood of galaxies
|
||||
ll_per_simulation = jax.numpy.sum(ll_all, axis=-1)
|
||||
# Normalization by summing the likelihood over simulations
|
||||
norm = jax.scipy.special.logsumexp(ll_per_simulation, axis=-1)
|
||||
return ll_per_simulation - norm[:, None]
|
||||
|
||||
|
||||
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
||||
"""Run the NumPyro model and save output to a file."""
|
||||
|
@ -128,6 +141,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
|
||||
mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
|
||||
samples = mcmc.get_samples()
|
||||
simulation_weights = get_simulation_weights(samples, model, model_kwargs)
|
||||
|
||||
log_posterior = -mcmc.get_extra_fields()["potential_energy"]
|
||||
log_likelihood = samples.pop("ll_values")
|
||||
|
@ -141,13 +155,13 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
|
||||
if calculate_evidence:
|
||||
print("Calculating the evidence using `harmonic`.", flush=True)
|
||||
ln_evidence, ln_evidence_err = get_harmonic_evidence(
|
||||
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
|
||||
samples, log_posterior, nchains_harmonic, epoch_num)
|
||||
print(f"{'ln(Z)':<20} {ln_evidence}")
|
||||
print(f"{'ln(Z) error':<20} {ln_evidence_err}")
|
||||
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
|
||||
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
|
||||
else:
|
||||
ln_evidence = jax.numpy.nan
|
||||
ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||
neg_ln_evidence = jax.numpy.nan
|
||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||
|
||||
fname = f"samples_{ARGS.simname}_{ARGS.catalogue}_ksmooth{ARGS.ksmooth}.hdf5" # noqa
|
||||
if ARGS.ksim is not None:
|
||||
|
@ -167,13 +181,14 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
# Write log likelihood and posterior
|
||||
f.create_dataset("log_likelihood", data=log_likelihood)
|
||||
f.create_dataset("log_posterior", data=log_posterior)
|
||||
f.create_dataset("simulation_weights", data=simulation_weights)
|
||||
|
||||
# Write goodness of fit
|
||||
grp = f.create_group("gof")
|
||||
grp.create_dataset("BIC", data=BIC)
|
||||
grp.create_dataset("AIC", data=AIC)
|
||||
grp.create_dataset("lnZ", data=ln_evidence)
|
||||
grp.create_dataset("lnZ_err", data=ln_evidence_err)
|
||||
grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
|
||||
grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
|
||||
|
||||
fname_summary = fname.replace(".hdf5", ".txt")
|
||||
print(f"Saving summary to `{fname_summary}`.")
|
||||
|
@ -188,8 +203,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
print("HMC summary:")
|
||||
print(f"{'BIC':<20} {BIC}")
|
||||
print(f"{'AIC':<20} {AIC}")
|
||||
print(f"{'ln(Z)':<20} {ln_evidence}")
|
||||
print(f"{'ln(Z) error':<20} {ln_evidence_err}")
|
||||
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
|
||||
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
|
||||
mcmc.print_summary(exclude_deterministic=False)
|
||||
sys.stdout = original_stdout
|
||||
|
||||
|
@ -238,7 +253,7 @@ if __name__ == "__main__":
|
|||
print_variables(
|
||||
calibration_hyperparams.keys(), calibration_hyperparams.values())
|
||||
|
||||
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups"]: # noqa
|
||||
if ARGS.catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]: # noqa
|
||||
distmod_hyperparams = {"e_mu_mean": 0.1, "e_mu_std": 0.05,
|
||||
"mag_cal_mean": -18.25, "mag_cal_std": 0.5,
|
||||
"alpha_cal_mean": 0.148, "alpha_cal_std": 0.05,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue