# Copyright (C) 2024 Richard Stiskalek
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Script to run the PV validation model on various catalogues and simulations.
The script is not MPI-parallelised; instead, it is best run on a GPU.
"""
from argparse import ArgumentParser, ArgumentTypeError


def none_or_int(value):
    """
    Parse the `--ksim` argument.

    Accepted forms:
      - `"none"` (case-insensitive) -> `None`,
      - `"k0_kf"` or `"k0_kf_dk"` -> `list(range(k0, kf, dk))` (`dk`
        defaults to 1),
      - anything else is parsed as a single integer.

    Raises
    ------
    ArgumentTypeError
        If the value has the wrong number of `_`-separated fields, describes
        an empty range, or is not an integer.
    """
    if value.lower() == "none":
        return None

    if "_" in value:
        args = value.split("_")
        if len(args) == 2:
            k0, kf = args
            dk = 1
        elif len(args) == 3:
            k0, kf, dk = args
        else:
            raise ArgumentTypeError(
                f"Invalid length of arguments: `{value}`.")

        ks = list(range(int(k0), int(kf), int(dk)))
        # An empty range would make the script silently do nothing.
        if not ks:
            raise ArgumentTypeError(
                f"Empty range of realisations: `{value}`.")
        return ks

    try:
        return int(value)
    except ValueError:
        raise ArgumentTypeError(
            f"Invalid value: {value}. Must be an integer or 'none'.") from None  # noqa


def parse_args():
    """
    Parse the command-line arguments.

    The comma-separated `--catalogue` string is split into a list of
    catalogue names before returning.
    """
    parser = ArgumentParser()
    parser.add_argument("--simname", type=str, required=True,
                        help="Simulation name.")
    parser.add_argument("--catalogue", type=str, required=True,
                        help="PV catalogues.")
    parser.add_argument("--ksmooth", type=int, default=0,
                        help="Smoothing index.")
    parser.add_argument("--ksim", type=none_or_int, default=None,
                        help="IC iteration number. If 'None', all IC realizations are used.")  # noqa
    parser.add_argument("--ndevice", type=int, default=1,
                        help="Number of devices to request.")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use.")
    args = parser.parse_args()

    # Convert the catalogue to a list of catalogues
    args.catalogue = args.catalogue.split(",")

    return args


ARGS = parse_args()
# This must be done before we import JAX etc.
from numpyro import set_host_device_count, set_platform  # noqa

set_platform(ARGS.device)  # noqa
set_host_device_count(ARGS.ndevice)  # noqa

import sys  # noqa
from contextlib import redirect_stdout  # noqa
from os.path import join  # noqa

import csiborgtools  # noqa
import jax  # noqa
import numpy as np  # noqa
from csiborgtools import fprint  # noqa
from h5py import File  # noqa
from numpyro.infer import MCMC, NUTS, init_to_median  # noqa


def print_variables(names, variables):
    """Print each name/value pair on its own line, followed by a blank line."""
    for name, variable in zip(names, variables):
        print(f"{name:<20} {variable}", flush=True)
    print(flush=True)


def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
               verbose=True):
    """
    Load the data and create the NumPyro models.

    Parameters
    ----------
    ksim : int or None
        Index into `paths.get_ics(ARGS.simname)`. If `None`, all
        realisations are passed to the data loader.
    get_model_kwargs : dict
        Extra keyword arguments forwarded to `csiborgtools.flow.get_model`.
    mag_selection : sequence
        Per-catalogue magnitude selection, indexed in the same order as
        `ARGS.catalogue`.
    void_kwargs : dict or None
        Forwarded to `csiborgtools.flow.get_model`.
    verbose : bool, optional
        Whether to print a short summary of the setup.

    Returns
    -------
    list
        One model per entry of `ARGS.catalogue`.

    Raises
    ------
    ValueError
        If a catalogue name is not supported.
    """
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
    folder = "/mnt/extraspace/rstiskalek/catalogs/"

    nsims = paths.get_ics(ARGS.simname)
    if ksim is None:
        nsim_iterator = list(range(len(nsims)))
    else:
        nsim_iterator = [ksim]
        nsims = [nsims[ksim]]

    if verbose:
        print(f"{'Simulation:':<20} {ARGS.simname}")
        print(f"{'Catalogue:':<20} {ARGS.catalogue}")
        print(f"{'Num. realisations:':<20} {len(nsims)}")
        print(flush=True)

    # Get models
    models = [None] * len(ARGS.catalogue)
    for i, cat in enumerate(ARGS.catalogue):
        if cat == "A2":
            fpath = join(folder, "A2.h5")
        elif cat in ["LOSS", "Foundation", "Pantheon+", "SFI_gals",
                     "2MTF", "SFI_groups", "SFI_gals_masked",
                     "Pantheon+_groups", "Pantheon+_groups_zSN",
                     "Pantheon+_zSN"]:
            fpath = join(folder, "PV_compilation.hdf5")
        elif "CF4_TFR" in cat:
            fpath = join(folder, "PV/CF4/CF4_TF-distances.hdf5")
        elif cat in ["CF4_GroupAll"]:
            fpath = join(folder, "PV/CF4/CF4_GroupAll.hdf5")
        else:
            # Report the offending entry, not the whole catalogue list.
            raise ValueError(f"Unsupported catalogue: `{cat}`.")

        loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
                                              cat, fpath, paths,
                                              ksmooth=ARGS.ksmooth)
        models[i] = csiborgtools.flow.get_model(
            loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
            **get_model_kwargs)

        fprint(f"num. radial steps is {len(loader.rdist)}")

    return models


def select_void_h(kind):
    """
    Return the `h` value associated with the void profile `kind`.

    Raises
    ------
    ValueError
        If `kind` is not one of `"mb"`, `"gauss"` or `"exp"`.
    """
    hs = {"mb": 0.7615, "gauss": 0.7724, "exp": 0.7725}
    try:
        return hs[kind]
    except KeyError:
        raise ValueError(f"Unknown void kind: `{kind}`.") from None


def get_harmonic_evidence(samples, log_posterior, nchains_harmonic,
                          epoch_num):
    """
    Compute evidence using the `harmonic` package.

    The flattened samples and log posterior are reshaped into
    `nchains_harmonic` pseudo-chains before being passed to
    `csiborgtools.harmonic_evidence`. The number of samples must therefore
    be divisible by `nchains_harmonic`.
    """
    data, names = csiborgtools.dict_samples_to_array(samples)
    data = data.reshape(nchains_harmonic, -1, len(names))
    log_posterior = log_posterior.reshape(nchains_harmonic, -1)

    return csiborgtools.harmonic_evidence(
        data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)


def run_model(model, nsteps, nburn, model_kwargs, out_folder,
              calculate_harmonic, nchains_harmonic, epoch_num, kwargs_print,
              fname_kwargs):
    """
    Run the NumPyro model and save output to a file.

    Runs NUTS (fixed PRNG seed 42) and prints BIC/AIC. Optionally estimates
    the evidence with `harmonic`. Writes the samples, log posterior and
    goodness-of-fit statistics to an HDF5 file, and a human-readable
    summary to a `.txt` file next to it.

    Raises
    ------
    AttributeError
        If any model in `model_kwargs["models"]` has no `ndata` attribute.
    """
    paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
    fname = paths.flow_validation(
        out_folder, ARGS.simname, ARGS.catalogue, **fname_kwargs)

    try:
        ndata = sum(m.ndata for m in model_kwargs["models"])
    except AttributeError as e:
        raise AttributeError("The models must have an attribute `ndata` "
                             "indicating the number of data points.") from e

    nuts_kernel = NUTS(model, init_strategy=init_to_median(num_samples=10000))
    mcmc = MCMC(nuts_kernel, num_warmup=nburn, num_samples=nsteps)
    rng_key = jax.random.PRNGKey(42)

    mcmc.run(rng_key, extra_fields=("potential_energy",), **model_kwargs)
    samples = mcmc.get_samples()

    # The potential energy is the negative (unnormalised) log posterior.
    log_posterior = -mcmc.get_extra_fields()["potential_energy"]
    BIC, AIC = csiborgtools.BIC_AIC(samples, log_posterior, ndata)
    print(f"{'BIC':<20} {BIC}")
    print(f"{'AIC':<20} {AIC}")
    mcmc.print_summary()

    if calculate_harmonic:
        print("Calculating the evidence using `harmonic`.", flush=True)
        neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
            samples, log_posterior, nchains_harmonic, epoch_num)
        print(f"{'-ln(Z_h)':<20} {neg_ln_evidence}")
        print(f"{'-ln(Z_h) error':<20} {neg_ln_evidence_err}")
    else:
        neg_ln_evidence = jax.numpy.nan
        neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)

    # NOTE(review): if `paths.flow_validation` already returns an absolute
    # path, `os.path.join` returns it unchanged and this is a no-op.
    fname = join(out_folder, fname)
    print(f"Saving results to `{fname}`.")
    with File(fname, "w") as f:
        # Write samples
        grp = f.create_group("samples")
        for key, value in samples.items():
            grp.create_dataset(key, data=value)

        # Write log likelihood and posterior
        f.create_dataset("log_posterior", data=log_posterior)

        # Write goodness of fit
        grp = f.create_group("gof")
        grp.create_dataset("BIC", data=BIC)
        grp.create_dataset("AIC", data=AIC)
        grp.create_dataset("neg_lnZ_harmonic", data=neg_ln_evidence)
        grp.create_dataset("neg_lnZ_harmonic_err", data=neg_ln_evidence_err)

    fname_summary = fname.replace(".hdf5", ".txt")
    print(f"Saving summary to `{fname_summary}`.")
    # `redirect_stdout` restores `sys.stdout` even if a print below raises,
    # so stdout is never left pointing at a closed file.
    with open(fname_summary, 'w') as f, redirect_stdout(f):
        print("User parameters:")
        for kwargs in kwargs_print:
            print_variables(kwargs.keys(), kwargs.values())

        print("HMC summary:")
        print(f"{'BIC':<20} {BIC}")
        print(f"{'AIC':<20} {AIC}")
        print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
        print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
        mcmc.print_summary(exclude_deterministic=False)


###############################################################################
#                        Command line interface                               #
###############################################################################


def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
    """
    Return the distance-modulus prior hyperparameters for `catalogue`.

    Raises
    ------
    ValueError
        If `catalogue` has no hyperparameters defined here.
    """
    # NOTE(review): `get_models` also accepts "SFI_groups", "SFI_gals_masked"
    # and "Pantheon+_groups_zSN", which are rejected below -- confirm intent.
    alpha_min = -1.0
    alpha_max = 10.0

    if catalogue in ["LOSS", "Foundation", "Pantheon+", "Pantheon+_groups", "Pantheon+_zSN"]:  # noqa
        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                "mag_cal_mean": -18.25, "mag_cal_std": 2.0,
                "alpha_cal_mean": 0.148, "alpha_cal_std": 1.0,
                "beta_cal_mean": 3.112, "beta_cal_std": 2.0,
                "alpha_min": alpha_min, "alpha_max": alpha_max,
                "sample_alpha": sample_alpha
                }
    elif catalogue in ["SFI_gals", "2MTF"] or "CF4_TFR" in catalogue:
        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                "a_mean": -21., "a_std": 5.0,
                "b_mean": -5.95, "b_std": 4.0,
                "c_mean": 0., "c_std": 20.0,
                "a_dipole_mean": 0., "a_dipole_std": 1.0,
                "sample_a_dipole": sample_mag_dipole,
                "alpha_min": alpha_min, "alpha_max": alpha_max,
                "sample_alpha": sample_alpha,
                }
    elif catalogue in ["CF4_GroupAll"]:
        return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                "dmu_min": -3.0, "dmu_max": 3.0,
                "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
                "sample_dmu_dipole": sample_mag_dipole,
                "alpha_min": alpha_min, "alpha_max": alpha_max,
                "sample_alpha": sample_alpha,
                }
    else:
        # Report the catalogue actually passed in, not the global list.
        raise ValueError(f"Unsupported catalogue: `{catalogue}`.")


def get_toy_selection(catalogue):
    """
    Toy magnitude selection coefficients.

    Returns a dict with keys `kind` (`"soft"` or `"hard"`) and `coeffs`,
    or `None` if no coefficients are defined for `catalogue`.
    """
    if catalogue == "SFI_gals":
        kind = "soft"
        # m1, m2, a
        coeffs = [11.467, 12.906, -0.231]
    elif "CF4_TFR" in catalogue and "_i" in catalogue:
        kind = "soft"
        coeffs = [13.043, 14.423, -0.129]
    elif "CF4_TFR" in catalogue and "w1" in catalogue:
        kind = "soft"
        coeffs = [11.731, 14.189, -0.118]
    elif catalogue == "2MTF":
        kind = "hard"
        coeffs = 11.25
    else:
        fprint(f"found no selection coefficients for {catalogue}.")
        return None

    return {"kind": kind,
            "coeffs": coeffs}


if __name__ == "__main__":
    out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity"  # noqa
    print(f"{'Num. devices:':<20} {jax.device_count()}")
    print(f"{'Devices:':<20} {jax.devices()}")

    ###########################################################################
    #                        Fixed user parameters                            #
    ###########################################################################

    # `None` means default behaviour
    nsteps = 10_000
    nburn = 2_000
    zcmb_min = None
    zcmb_max = 0.05
    nchains_harmonic = 10
    num_epochs = 50
    inference_method = "mike"
    mag_selection = None
    sample_alpha = not ("IndranilVoid_" in ARGS.simname or ARGS.simname == "no_field")  # noqa
    sample_beta = None
    no_Vext = None
    sample_Vmag_vax = False
    sample_Vmono = False
    sample_mag_dipole = False
    absolute_calibration = None
    calculate_harmonic = inference_method != "bayes"
    sample_h = absolute_calibration is not None

    fname_kwargs = {"inference_method": inference_method,
                    "smooth": ARGS.ksmooth,
                    "nsim": ARGS.ksim,
                    "zcmb_min": zcmb_min,
                    "zcmb_max": zcmb_max,
                    "mag_selection": mag_selection,
                    "sample_alpha": sample_alpha,
                    "sample_beta": sample_beta,
                    "no_Vext": no_Vext,
                    "sample_Vmag_vax": sample_Vmag_vax,
                    "sample_Vmono": sample_Vmono,
                    "sample_mag_dipole": sample_mag_dipole,
                    "absolute_calibration": absolute_calibration,
                    }

    main_params = {"nsteps": nsteps, "nburn": nburn,
                   "zcmb_min": zcmb_min,
                   "zcmb_max": zcmb_max,
                   "mag_selection": mag_selection,
                   "calculate_harmonic": calculate_harmonic,
                   "nchains_harmonic": nchains_harmonic,
                   "num_epochs": num_epochs,
                   "inference_method": inference_method,
                   "sample_mag_dipole": sample_mag_dipole,
                   "absolute_calibration": absolute_calibration,
                   "sample_h": sample_h,
                   }
    print_variables(main_params.keys(), main_params.values())

    if sample_beta is None:
        sample_beta = ARGS.simname == "Carrick2015"

    if mag_selection and inference_method != "bayes":
        raise ValueError("Magnitude selection is only supported with `bayes` inference.")  # noqa

    if "IndranilVoid" in ARGS.simname:
        if ARGS.ksim is not None:
            raise ValueError(
                "`IndranilVoid` does not have multiple realisations.")

        kind = ARGS.simname.split("_")[-1]
        h = select_void_h(kind)
        rdist = np.arange(0, 165, 0.5)
        void_kwargs = {"kind": kind, "h": h, "order": 1, "rdist": rdist}
    else:
        void_kwargs = None
        h = 1.

    # One selection entry per catalogue. Any case other than "bayes with
    # selection enabled (or left at default)" -- including an explicit
    # `False` -- yields no selection, so `get_models` can always index it.
    if inference_method == "bayes" and (mag_selection is None or mag_selection):  # noqa
        mag_selection = [get_toy_selection(cat) for cat in ARGS.catalogue]
    else:
        mag_selection = [None] * len(ARGS.catalogue)

    # Required by the reshape in `get_harmonic_evidence`.
    if nsteps % nchains_harmonic != 0:
        raise ValueError(
            "The number of steps must be divisible by the number of chains.")

    calibration_hyperparams = {"Vext_min": -3000, "Vext_max": 3000,
                               "Vmono_min": -1000, "Vmono_max": 1000,
                               "beta_min": -10.0, "beta_max": 10.0,
                               "sigma_v_min": 1.0,
                               "sigma_v_max": 5000 if "IndranilVoid_" in ARGS.simname else 750.,  # noqa
                               "h_min": 0.01, "h_max": 5.0,
                               "no_Vext": False if no_Vext is None else no_Vext,  # noqa
                               "sample_Vmag_vax": sample_Vmag_vax,
                               "sample_Vmono": sample_Vmono,
                               "sample_beta": sample_beta,
                               "sample_h": sample_h,
                               "sample_rLG": "IndranilVoid" in ARGS.simname,
                               "rLG_min": 0.0, "rLG_max": 500 * h,
                               }
    print_variables(
        calibration_hyperparams.keys(), calibration_hyperparams.values())

    distmod_hyperparams_per_catalogue = []
    for cat in ARGS.catalogue:
        x = get_distmod_hyperparams(cat, sample_alpha, sample_mag_dipole)
        print(f"\n{cat} hyperparameters:")
        print_variables(x.keys(), x.values())
        distmod_hyperparams_per_catalogue.append(x)

    kwargs_print = (main_params, calibration_hyperparams,
                    *distmod_hyperparams_per_catalogue)
    ###########################################################################

    get_model_kwargs = {
        "zcmb_min": zcmb_min,
        "zcmb_max": zcmb_max,
        "absolute_calibration": absolute_calibration,
        "calibration_fpath": "/mnt/extraspace/rstiskalek/catalogs/PV/CF4/CF4_TF_calibration.hdf5",  # noqa
        }

    # In case we want to run multiple simulations independently.
    ksim_iterator = ARGS.ksim if isinstance(ARGS.ksim, list) else [ARGS.ksim]

    for i, ksim in enumerate(ksim_iterator):
        if len(ksim_iterator) > 1:
            print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.")  # noqa

        fname_kwargs["nsim"] = ksim
        models = get_models(ksim, get_model_kwargs, mag_selection,
                            void_kwargs)
        model_kwargs = {
            "models": models,
            "field_calibration_hyperparams": calibration_hyperparams,
            "distmod_hyperparams_per_model": distmod_hyperparams_per_catalogue,  # noqa
            "inference_method": inference_method,
            }

        model = csiborgtools.flow.PV_validation_model

        run_model(model, nsteps, nburn, model_kwargs, out_folder,
                  calculate_harmonic, nchains_harmonic, num_epochs,
                  kwargs_print, fname_kwargs)