mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-07-02 20:41:11 +00:00
Add more about evidence and selection to flow (#142)
* Add Laplace evidence * Numerically stable laplace evidence * Minor edits to Laplace * Remove rmax * Rm old things * Rm comments * Add script * Add super toy selection * Add super toy selection * Update script
This commit is contained in:
parent
d13246a394
commit
3d1e1c0ae3
8 changed files with 243 additions and 57 deletions
|
@ -72,7 +72,7 @@ def print_variables(names, variables):
|
|||
print(flush=True)
|
||||
|
||||
|
||||
def get_models(get_model_kwargs, verbose=True):
|
||||
def get_models(get_model_kwargs, toy_selection, verbose=True):
|
||||
"""Load the data and create the NumPyro models."""
|
||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||
|
@ -110,7 +110,8 @@ def get_models(get_model_kwargs, verbose=True):
|
|||
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
||||
cat, fpath, paths,
|
||||
ksmooth=ARGS.ksmooth)
|
||||
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
||||
models[i] = csiborgtools.flow.get_model(
|
||||
loader, toy_selection=toy_selection[i], **get_model_kwargs)
|
||||
|
||||
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
||||
return models
|
||||
|
@ -127,7 +128,7 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
|||
|
||||
|
||||
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
||||
calculate_harmonic, nchains_harmonic, epoch_num, kwargs_print):
|
||||
"""Run the NumPyro model and save output to a file."""
|
||||
try:
|
||||
ndata = sum(model.ndata for model in model_kwargs["models"])
|
||||
|
@ -148,12 +149,12 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
print(f"{'AIC':<20} {AIC}")
|
||||
mcmc.print_summary()
|
||||
|
||||
if calculate_evidence:
|
||||
if calculate_harmonic:
|
||||
print("Calculating the evidence using `harmonic`.", flush=True)
|
||||
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
|
||||
samples, log_posterior, nchains_harmonic, epoch_num)
|
||||
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
|
||||
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
|
||||
print(f"{'-ln(Z_h)':<20} {neg_ln_evidence}")
|
||||
print(f"{'-ln(Z_h) error':<20} {neg_ln_evidence_err}")
|
||||
else:
|
||||
neg_ln_evidence = jax.numpy.nan
|
||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||
|
@ -180,8 +181,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
grp = f.create_group("gof")
|
||||
grp.create_dataset("BIC", data=BIC)
|
||||
grp.create_dataset("AIC", data=AIC)
|
||||
grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
|
||||
grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
|
||||
grp.create_dataset("neg_lnZ_harmonic", data=neg_ln_evidence)
|
||||
grp.create_dataset("neg_lnZ_harmonic_err", data=neg_ln_evidence_err)
|
||||
|
||||
fname_summary = fname.replace(".hdf5", ".txt")
|
||||
print(f"Saving summary to `{fname_summary}`.")
|
||||
|
@ -206,7 +207,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
|||
# Command line interface #
|
||||
###############################################################################
|
||||
|
||||
def get_distmod_hyperparams(catalogue, sample_alpha):
|
||||
def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
|
||||
alpha_min = -1.0
|
||||
alpha_max = 3.0
|
||||
|
||||
|
@ -225,7 +226,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
|||
"c_mean": 0., "c_std": 20.0,
|
||||
"sample_curvature": False,
|
||||
"a_dipole_mean": 0., "a_dipole_std": 1.0,
|
||||
"sample_a_dipole": True,
|
||||
"sample_a_dipole": sample_mag_dipole,
|
||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||
"sample_alpha": sample_alpha,
|
||||
}
|
||||
|
@ -233,7 +234,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
|||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||
"dmu_min": -3.0, "dmu_max": 3.0,
|
||||
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
|
||||
"sample_dmu_dipole": True,
|
||||
"sample_dmu_dipole": sample_mag_dipole,
|
||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||
"sample_alpha": sample_alpha,
|
||||
}
|
||||
|
@ -241,6 +242,16 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
|||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
|
||||
|
||||
def get_toy_selection(toy_selection, catalogue):
|
||||
if not toy_selection:
|
||||
return None
|
||||
|
||||
if catalogue == "SFI_gals":
|
||||
return [1.221e+01, 1.297e+01, -2.708e-01]
|
||||
else:
|
||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||
out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity" # noqa
|
||||
|
@ -251,18 +262,23 @@ if __name__ == "__main__":
|
|||
# Fixed user parameters #
|
||||
###########################################################################
|
||||
|
||||
nsteps = 1500
|
||||
nburn = 1000
|
||||
nsteps = 1000
|
||||
nburn = 500
|
||||
zcmb_min = 0
|
||||
zcmb_max = 0.05
|
||||
calculate_evidence = False
|
||||
nchains_harmonic = 10
|
||||
num_epochs = 30
|
||||
inference_method = "mike"
|
||||
num_epochs = 50
|
||||
inference_method = "bayes"
|
||||
calculate_harmonic = True if inference_method == "mike" else False
|
||||
maxmag_selection = None
|
||||
sample_alpha = True
|
||||
sample_alpha = False
|
||||
sample_beta = True
|
||||
sample_Vmono = False
|
||||
sample_mag_dipole = False
|
||||
toy_selection = True
|
||||
|
||||
if toy_selection and inference_method == "mike":
|
||||
raise ValueError("Toy selection is not supported with `mike` inference.") # noqa
|
||||
|
||||
if nsteps % nchains_harmonic != 0:
|
||||
raise ValueError(
|
||||
|
@ -272,10 +288,12 @@ if __name__ == "__main__":
|
|||
"zcmb_min": zcmb_min,
|
||||
"zcmb_max": zcmb_max,
|
||||
"maxmag_selection": maxmag_selection,
|
||||
"calculate_evidence": calculate_evidence,
|
||||
"calculate_harmonic": calculate_harmonic,
|
||||
"nchains_harmonic": nchains_harmonic,
|
||||
"num_epochs": num_epochs,
|
||||
"inference_method": inference_method}
|
||||
"inference_method": inference_method,
|
||||
"sample_mag_dipole": sample_mag_dipole,
|
||||
"toy_selection": toy_selection}
|
||||
print_variables(main_params.keys(), main_params.values())
|
||||
|
||||
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
||||
|
@ -290,7 +308,7 @@ if __name__ == "__main__":
|
|||
|
||||
distmod_hyperparams_per_catalogue = []
|
||||
for cat in ARGS.catalogue:
|
||||
x = get_distmod_hyperparams(cat, sample_alpha)
|
||||
x = get_distmod_hyperparams(cat, sample_alpha, sample_mag_dipole)
|
||||
print(f"\n{cat} hyperparameters:")
|
||||
print_variables(x.keys(), x.values())
|
||||
distmod_hyperparams_per_catalogue.append(x)
|
||||
|
@ -301,7 +319,11 @@ if __name__ == "__main__":
|
|||
|
||||
get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
|
||||
"maxmag_selection": maxmag_selection}
|
||||
models = get_models(get_model_kwargs, )
|
||||
|
||||
toy_selection = [get_toy_selection(toy_selection, cat)
|
||||
for cat in ARGS.catalogue]
|
||||
|
||||
models = get_models(get_model_kwargs, toy_selection)
|
||||
model_kwargs = {
|
||||
"models": models,
|
||||
"field_calibration_hyperparams": calibration_hyperparams,
|
||||
|
@ -312,5 +334,5 @@ if __name__ == "__main__":
|
|||
model = csiborgtools.flow.PV_validation_model
|
||||
|
||||
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
||||
calibration_hyperparams["sample_beta"], calculate_evidence,
|
||||
calibration_hyperparams["sample_beta"], calculate_harmonic,
|
||||
nchains_harmonic, num_epochs, kwargs_print)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue