From 3d1e1c0ae39db0a247932d3ff7ab0b5675eb4f42 Mon Sep 17 00:00:00 2001
From: Richard Stiskalek
Date: Tue, 27 Aug 2024 00:36:00 +0200
Subject: [PATCH] Add more about evidence and selection to flow (#142)

* Add Laplace evidence

* Numerically stable laplace evidence

* Minor edits to Laplace

* Remove rmax

* Rm old things

* Rm comments

* Add script

* Add super toy selection

* Add super toy selection

* Update script
---
 csiborgtools/flow/__init__.py   |  1 +
 csiborgtools/flow/flow_model.py | 95 +++++++++++++++++++++------------
 csiborgtools/flow/selection.py  | 69 ++++++++++++++++++++++++
 csiborgtools/utils.py           |  5 ++
 scripts/field_prop/field_los.sh |  2 +-
 scripts/flow/flow_validation.py | 66 +++++++++++++++--------
 scripts/flow/flow_validation.sh |  2 +-
 scripts/flow/test_harmonic.py   | 60 +++++++++++++++++++++
 8 files changed, 243 insertions(+), 57 deletions(-)
 create mode 100644 csiborgtools/flow/selection.py
 create mode 100644 scripts/flow/test_harmonic.py

diff --git a/csiborgtools/flow/__init__.py b/csiborgtools/flow/__init__.py
index 0ee633b..9061201 100644
--- a/csiborgtools/flow/__init__.py
+++ b/csiborgtools/flow/__init__.py
@@ -17,3 +17,4 @@ from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model,
                          Observed2CosmologicalRedshift, predict_zobs,  # noqa
                          project_Vext, radial_velocity_los,  # noqa
                          stack_pzosmo_over_realizations)  # noqa
+from .selection import ToyMagnitudeSelection  # noqa
diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py
index 3c48bd9..f015480 100644
--- a/csiborgtools/flow/flow_model.py
+++ b/csiborgtools/flow/flow_model.py
@@ -35,6 +35,7 @@ from numpyro.distributions import Normal, Uniform, MultivariateNormal
 from quadax import simpson
 from tqdm import trange
 
+from .selection import toy_log_magnitude_selection
 from ..params import SPEED_OF_LIGHT, simname2Omega_m
 from ..utils import fprint, radec_to_galactic, radec_to_supergalactic
 
@@ -78,7 +79,7 @@ class DataLoader:
         self._catname = catalogue
 
         fprint("reading the interpolated field.", verbose)
-        self._field_rdist, self._los_density, self._los_velocity, self._rmax = self._read_field(  # noqa
+        self._field_rdist, self._los_density, self._los_velocity = self._read_field(  # noqa
             simname, ksim, catalogue, ksmooth, paths)
 
         if len(self._cat) != self._los_density.shape[1]:
@@ -169,14 +170,6 @@ class DataLoader:
 
         return self._los_velocity[:, :, self._mask, ...]
 
-    @property
-    def rmax(self):
-        """
-        Radial distance above which the underlying reconstruction is
-        extrapolated `(n_sims, n_objects)`.
-        """
-        return self._rmax[:, self._mask]
-
     @property
     def los_radial_velocity(self):
         """
@@ -201,7 +194,6 @@ class DataLoader:
 
         los_density = [None] * len(ksims)
         los_velocity = [None] * len(ksims)
-        rmax = [None] * len(ksims)
 
         for n, ksim in enumerate(ksims):
             nsim = nsims[ksim]
@@ -216,13 +208,11 @@ class DataLoader:
                 los_density[n] = f[f"density_{nsim}"][indx]
                 los_velocity[n] = f[f"velocity_{nsim}"][indx]
                 rdist = f[f"rdist_{nsim}"][...]
-                rmax[n] = f[f"rmax_{nsim}"][indx]
 
         los_density = np.stack(los_density)
         los_velocity = np.stack(los_velocity)
-        rmax = np.stack(rmax)
 
-        return rdist, los_density, los_velocity, rmax
+        return rdist, los_density, los_velocity
 
     def _read_catalogue(self, catalogue, catalogue_fpath):
         if catalogue == "A2":
@@ -622,9 +612,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         LOS density field.
     los_velocity : 3-dimensional array of shape (n_sims, n_objects, n_steps)
         LOS radial velocity field.
-    rmax : 1-dimensional array of shape (n_sims, n_objects)
-        Radial distance above which the underlying reconstruction is
-        extrapolated.
     RA, dec : 1-dimensional arrays of shape (n_objects)
         Right ascension and declination in degrees.
     z_obs : 1-dimensional array of shape (n_objects)
@@ -643,11 +630,13 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         Catalogue kind, either "TFR", "SN", or "simple".
     name : str
         Name of the catalogue.
+    toy_selection : tuple of length 3, optional
+        Toy magnitude selection parameters `m1`, `m2` and `a`.
     """
 
-    def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
-                 e_zobs, calibration_params, maxmag_selection, r_xrange,
-                 Omega_m, kind, name):
+    def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
+                 calibration_params, maxmag_selection, r_xrange, Omega_m,
+                 kind, name, toy_selection=None):
         if e_zobs is not None:
             e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
         else:
@@ -657,9 +646,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         RA = np.deg2rad(RA)
         dec = np.deg2rad(dec)
 
-        names = ["los_density", "los_velocity", "rmax", "RA", "dec", "z_obs",
+        names = ["los_density", "los_velocity", "RA", "dec", "z_obs",
                  "e2_cz_obs"]
-        values = [los_density, los_velocity, rmax, RA, dec, z_obs, e2_cz_obs]
+        values = [los_density, los_velocity, RA, dec, z_obs, e2_cz_obs]
         self._setattr_as_jax(names, values)
         self._set_calibration_params(calibration_params)
         self._set_radial_spacing(r_xrange, Omega_m)
@@ -669,6 +658,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         self.Omega_m = Omega_m
         self.norm = - self.ndata * jnp.log(self.num_sims)
         self.maxmag_selection = maxmag_selection
+        self.toy_selection = toy_selection
 
         if kind == "TFR":
             self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
@@ -688,6 +678,17 @@ class PV_LogLikelihood(BaseFlowValidationModel):
         if maxmag_selection is not None and self.maxmag_selection > self.mag_max:  # noqa
             raise ValueError("The maximum magnitude cannot be larger than the selection threshold.")  # noqa
 
+        if toy_selection is not None and self.maxmag_selection is not None:
+            raise ValueError("`toy_selection` and `maxmag_selection` cannot be used together.")  # noqa
+
+        if toy_selection is not None:
+            self.m1, self.m2, self.a = toy_selection
+            self.log_Fm = toy_log_magnitude_selection(
+                self.mag, self.m1, self.m2, self.a)
+
+        if toy_selection is not None and self.kind != "TFR":
+            raise ValueError("Toy selection is only implemented for TFRs.")
+
     def __call__(self, field_calibration_params, distmod_params,
                  inference_method):
         if inference_method not in ["mike", "bayes"]:
@@ -772,12 +773,30 @@ class PV_LogLikelihood(BaseFlowValidationModel):
             mag_true, eta_true = x_true[..., 0], x_true[..., 1]
 
             # Log-likelihood of the observed magnitudes.
-            if self.maxmag_selection is None:
-                ll0 += jnp.sum(normal_logpdf(
-                    self.mag, mag_true, self.e_mag))
-            else:
+            if self.maxmag_selection is not None:
                 ll0 += jnp.sum(upper_truncated_normal_logpdf(
                     self.mag, mag_true, self.e_mag, self.maxmag_selection))
+            elif self.toy_selection is not None:
+                ll_mag = self.log_Fm
+                ll_mag += normal_logpdf(self.mag, mag_true, self.e_mag)
+
+                # Normalization per datapoint, initially (ndata, nxrange)
+                mu_start = mag_true - 5 * self.e_mag
+                mu_end = mag_true + 5 * self.e_mag
+                # 100 is a reasonable and sufficient choice.
+                mu_xrange = jnp.linspace(mu_start, mu_end, 100).T
+
+                norm = toy_log_magnitude_selection(
+                    mu_xrange, self.m1, self.m2, self.a)
+                norm = norm + normal_logpdf(
+                    mu_xrange, mag_true[:, None], self.e_mag[:, None])
+                # Now integrate over the magnitude range.
+                norm = simpson(jnp.exp(norm), x=mu_xrange, axis=-1)
+
+                ll0 += jnp.sum(ll_mag - jnp.log(norm))
+            else:
+                ll0 += jnp.sum(normal_logpdf(
+                    self.mag, mag_true, self.e_mag))
 
             # Log-likelihood of the observed linewidths.
             ll0 += jnp.sum(normal_logpdf(eta_true, self.eta, self.e_eta))
@@ -876,7 +895,8 @@ def PV_validation_model(models, distmod_hyperparams_per_model,
 ###############################################################################
 
 
-def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
+def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None,
+              toy_selection=None):
     """
     Get a model and extract the relevant data from the loader.
 
@@ -890,6 +910,9 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
         Maximum observed redshift in the CMB frame to include.
     maxmag_selection : float, optional
         Maximum magnitude selection threshold.
+    toy_selection : tuple of length 3, optional
+        Toy magnitude selection parameters `m1`, `m2` and `a` for TFRs of the
+        Boubel+24 model.
 
     Returns
     -------
@@ -899,7 +922,6 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
 
     los_overdensity = loader.los_density
     los_velocity = loader.los_radial_velocity
-    rmax = loader.rmax
     kind = loader._catname
 
     if maxmag_selection is not None and kind != "2MTF":
@@ -917,7 +939,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
                               "e_c": e_c[mask]}
 
         model = PV_LogLikelihood(
-            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            los_overdensity[:, mask], los_velocity[:, mask],
             RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
             maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
     elif "Pantheon+" in kind:
@@ -945,20 +967,27 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
                               "e_mag": e_mB[mask], "e_x1": e_x1[mask],
                               "e_c": e_c[mask]}
         model = PV_LogLikelihood(
-            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            los_overdensity[:, mask], los_velocity[:, mask],
             RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
             maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
     elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
         keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
         RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
 
+        if kind == "SFI_gals" and toy_selection is not None:
+            if len(toy_selection) != 3:
+                raise ValueError("Toy selection must be a tuple with 3 elements.")  # noqa
+            m1, m2, a = toy_selection
+            fprint(f"using toy selection with m1 = {m1}, m2 = {m2}, a = {a}.")
+
         mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)
         calibration_params = {"mag": mag[mask], "eta": eta[mask],
                               "e_mag": e_mag[mask], "e_eta": e_eta[mask]}
         model = PV_LogLikelihood(
-            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            los_overdensity[:, mask], los_velocity[:, mask],
             RA[mask], dec[mask], zCMB[mask], None, calibration_params,
-            maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
+            maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind,
+            toy_selection=toy_selection)
     elif "CF4_TFR_" in kind:
         # The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i".
         band = kind.split("_")[-1]
@@ -995,7 +1024,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
         calibration_params = {"mag": mag[mask], "eta": eta[mask],
                               "e_mag": e_mag[mask], "e_eta": e_eta[mask]}
         model = PV_LogLikelihood(
-            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            los_overdensity[:, mask], los_velocity[:, mask],
             RA[mask], dec[mask], z_obs[mask], None, calibration_params,
             maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
     elif kind in ["CF4_GroupAll"]:
@@ -1011,7 +1040,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
 
         calibration_params = {"mu": mu[mask], "e_mu": e_mu[mask]}
         model = PV_LogLikelihood(
-            los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
+            los_overdensity[:, mask], los_velocity[:, mask],
             RA[mask], dec[mask], zCMB[mask], None, calibration_params,
             maxmag_selection, loader.rdist, loader._Omega_m, "simple",
             name=kind)
diff --git a/csiborgtools/flow/selection.py b/csiborgtools/flow/selection.py
new file mode 100644
index 0000000..9fb3a58
--- /dev/null
+++ b/csiborgtools/flow/selection.py
@@ -0,0 +1,69 @@
+# Copyright (C) 2024 Richard Stiskalek
+# This program is free software; you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the
+# Free Software Foundation; either version 3 of the License, or (at your
+# option) any later version.
+#
+# This program is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
+# Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+"""Selection functions for peculiar velocities."""
+import numpy as np
+from jax import numpy as jnp
+from scipy.integrate import quad
+from scipy.optimize import minimize
+
+
+class ToyMagnitudeSelection:
+    """
+    Toy magnitude selection according to Boubel et al. (2024).
+    """
+
+    def __init__(self):
+        pass
+
+    def log_true_pdf(self, m, m1):
+        """Unnormalized 'true' PDF."""
+        return 0.6 * (m - m1)
+
+    def log_selection_function(self, m, m1, m2, a):
+        return np.where(m <= m1,
+                        0,
+                        a * (m - m2)**2 - a * (m1 - m2)**2 - 0.6 * (m - m1))
+
+    def log_observed_pdf(self, m, m1, m2, a):
+        # Calculate the normalization constant
+        f = lambda m: 10**(self.log_true_pdf(m, m1)  # noqa
+                           + self.log_selection_function(m, m1, m2, a))
+        mmin, mmax = 0, 25
+        norm = quad(f, mmin, mmax)[0]
+
+        return (self.log_true_pdf(m, m1)
+                + self.log_selection_function(m, m1, m2, a)
+                - np.log10(norm))
+
+    def fit(self, mag):
+
+        def loss(x):
+            m1, m2, a = x
+
+            if a >= 0:
+                return np.inf
+
+            return -np.sum(self.log_observed_pdf(mag, m1, m2, a))
+
+        x0 = [12.0, 12.5, -0.1]
+        return minimize(loss, x0, method="Nelder-Mead")
+
+
+def toy_log_magnitude_selection(mag, m1, m2, a):
+    """JAX implementation of `ToyMagnitudeSelection` but natural logarithm."""
+    return jnp.log(10) * jnp.where(
+        mag <= m1,
+        0,
+        a * (mag - m2)**2 - a * (m1 - m2)**2 - 0.6 * (mag - m1))
diff --git a/csiborgtools/utils.py b/csiborgtools/utils.py
index 2d129db..e62bc02 100644
--- a/csiborgtools/utils.py
+++ b/csiborgtools/utils.py
@@ -492,6 +492,11 @@ def dict_samples_to_array(samples):
             for i in range(value.shape[-1]):
                 data.append(value[:, i])
                 names.append(f"{key}_{i}")
+        elif value.ndim == 3:
+            for i in range(value.shape[-1]):
+                for j in range(value.shape[-2]):
+                    data.append(value[:, j, i])
+                    names.append(f"{key}_{j}_{i}")
         else:
             raise ValueError("Invalid dimensionality of samples to stack.")
 
diff --git a/scripts/field_prop/field_los.sh b/scripts/field_prop/field_los.sh
index 2fe3037..8ec545b 100755
--- a/scripts/field_prop/field_los.sh
+++ b/scripts/field_prop/field_los.sh
@@ -10,7 +10,7 @@ MAS="SPH"
 grid=1024
 
 
-for simname in "Carrick2015"; do
+for simname in "Lilow2024"; do
     for catalogue in "CF4_TFR"; do
         pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid"
         if [ $on_login -eq 1 ]; then
diff --git a/scripts/flow/flow_validation.py b/scripts/flow/flow_validation.py
index c159123..ec49a1f 100644
--- a/scripts/flow/flow_validation.py
+++ b/scripts/flow/flow_validation.py
@@ -72,7 +72,7 @@ def print_variables(names, variables):
     print(flush=True)
 
 
-def get_models(get_model_kwargs, verbose=True):
+def get_models(get_model_kwargs, toy_selection, verbose=True):
     """Load the data and create the NumPyro models."""
     paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
     folder = "/mnt/extraspace/rstiskalek/catalogs/"
@@ -110,7 +110,8 @@ def get_models(get_model_kwargs, verbose=True):
         loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
                                               cat, fpath, paths,
                                               ksmooth=ARGS.ksmooth)
-        models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
+        models[i] = csiborgtools.flow.get_model(
+            loader, toy_selection=toy_selection[i], **get_model_kwargs)
 
     print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
     return models
 
 
@@ -127,7 +128,7 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
 
 
 def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
-              calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
+              calculate_harmonic, nchains_harmonic, epoch_num, kwargs_print):
     """Run the NumPyro model and save output to a file."""
     try:
         ndata = sum(model.ndata for model in model_kwargs["models"])
@@ -148,12 +149,12 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
     print(f"{'AIC':<20} {AIC}")
     mcmc.print_summary()
 
-    if calculate_evidence:
+    if calculate_harmonic:
         print("Calculating the evidence using `harmonic`.", flush=True)
         neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
             samples, log_posterior, nchains_harmonic, epoch_num)
-        print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
-        print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
+        print(f"{'-ln(Z_h)':<20} {neg_ln_evidence}")
+        print(f"{'-ln(Z_h) error':<20} {neg_ln_evidence_err}")
     else:
         neg_ln_evidence = jax.numpy.nan
         neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
@@ -180,8 +181,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
         grp = f.create_group("gof")
         grp.create_dataset("BIC", data=BIC)
         grp.create_dataset("AIC", data=AIC)
-        grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
-        grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
+        grp.create_dataset("neg_lnZ_harmonic", data=neg_ln_evidence)
+        grp.create_dataset("neg_lnZ_harmonic_err", data=neg_ln_evidence_err)
 
     fname_summary = fname.replace(".hdf5", ".txt")
     print(f"Saving summary to `{fname_summary}`.")
@@ -206,7 +207,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
 #                        Command line interface                              #
 ###############################################################################
 
-def get_distmod_hyperparams(catalogue, sample_alpha):
+def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
     alpha_min = -1.0
     alpha_max = 3.0
 
@@ -225,7 +226,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
                 "c_mean": 0., "c_std": 20.0,
                 "sample_curvature": False,
                 "a_dipole_mean": 0., "a_dipole_std": 1.0,
-                "sample_a_dipole": True,
+                "sample_a_dipole": sample_mag_dipole,
                 "alpha_min": alpha_min, "alpha_max": alpha_max,
                 "sample_alpha": sample_alpha,
                 }
@@ -233,7 +234,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
         return {"e_mu_min": 0.001, "e_mu_max": 1.0,
                 "dmu_min": -3.0, "dmu_max": 3.0,
                 "dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
-                "sample_dmu_dipole": True,
+                "sample_dmu_dipole": sample_mag_dipole,
                 "alpha_min": alpha_min, "alpha_max": alpha_max,
                 "sample_alpha": sample_alpha,
                 }
@@ -241,6 +242,16 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
         raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
 
 
+def get_toy_selection(toy_selection, catalogue):
+    if not toy_selection:
+        return None
+
+    if catalogue == "SFI_gals":
+        return [1.221e+01, 1.297e+01, -2.708e-01]
+    else:
+        raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
+
+
 if __name__ == "__main__":
     paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
     out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity"  # noqa
@@ -251,18 +262,23 @@ if __name__ == "__main__":
     # Fixed user parameters                                                  #
     ###########################################################################
 
-    nsteps = 1500
-    nburn = 1000
+    nsteps = 1000
+    nburn = 500
     zcmb_min = 0
     zcmb_max = 0.05
-    calculate_evidence = False
     nchains_harmonic = 10
-    num_epochs = 30
-    inference_method = "mike"
+    num_epochs = 50
+    inference_method = "bayes"
+    calculate_harmonic = True if inference_method == "mike" else False
     maxmag_selection = None
-    sample_alpha = True
+    sample_alpha = False
     sample_beta = True
     sample_Vmono = False
+    sample_mag_dipole = False
+    toy_selection = True
+
+    if toy_selection and inference_method == "mike":
+        raise ValueError("Toy selection is not supported with `mike` inference.")  # noqa
 
     if nsteps % nchains_harmonic != 0:
         raise ValueError(
@@ -272,10 +288,12 @@ if __name__ == "__main__":
     main_params = {"nsteps": nsteps, "nburn": nburn,
                    "zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
                    "maxmag_selection": maxmag_selection,
-                   "calculate_evidence": calculate_evidence,
+                   "calculate_harmonic": calculate_harmonic,
                    "nchains_harmonic": nchains_harmonic,
                    "num_epochs": num_epochs,
-                   "inference_method": inference_method}
+                   "inference_method": inference_method,
+                   "sample_mag_dipole": sample_mag_dipole,
+                   "toy_selection": toy_selection}
     print_variables(main_params.keys(), main_params.values())
 
     calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
@@ -290,7 +308,7 @@ if __name__ == "__main__":
 
     distmod_hyperparams_per_catalogue = []
     for cat in ARGS.catalogue:
-        x = get_distmod_hyperparams(cat, sample_alpha)
+        x = get_distmod_hyperparams(cat, sample_alpha, sample_mag_dipole)
         print(f"\n{cat} hyperparameters:")
         print_variables(x.keys(), x.values())
         distmod_hyperparams_per_catalogue.append(x)
@@ -301,7 +319,11 @@ if __name__ == "__main__":
 
     get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
                         "maxmag_selection": maxmag_selection}
-    models = get_models(get_model_kwargs, )
+
+    toy_selection = [get_toy_selection(toy_selection, cat)
+                     for cat in ARGS.catalogue]
+
+    models = get_models(get_model_kwargs, toy_selection)
     model_kwargs = {
         "models": models,
         "field_calibration_hyperparams": calibration_hyperparams,
@@ -312,5 +334,5 @@ if __name__ == "__main__":
     model = csiborgtools.flow.PV_validation_model
 
     run_model(model, nsteps, nburn, model_kwargs, out_folder,
-              calibration_hyperparams["sample_beta"], calculate_evidence,
+              calibration_hyperparams["sample_beta"], calculate_harmonic,
               nchains_harmonic, num_epochs, kwargs_print)
diff --git a/scripts/flow/flow_validation.sh b/scripts/flow/flow_validation.sh
index 366bd3f..50f7b16 100755
--- a/scripts/flow/flow_validation.sh
+++ b/scripts/flow/flow_validation.sh
@@ -39,7 +39,7 @@ fi
 
 # for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
 for simname in "Carrick2015"; do
-    for catalogue in "CF4_GroupAll"; do
+    for catalogue in "SFI_gals"; do
     # for catalogue in "CF4_TFR_i"; do
         # for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
         for ksim in "none"; do
diff --git a/scripts/flow/test_harmonic.py b/scripts/flow/test_harmonic.py
new file mode 100644
index 0000000..69c6a4f
--- /dev/null
+++ b/scripts/flow/test_harmonic.py
@@ -0,0 +1,60 @@
+from argparse import ArgumentParser, ArgumentTypeError
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("--device", type=str, default="cpu",
+                        help="Device to use.")
+    return parser.parse_args()
+
+
+ARGS = parse_args()
+# This must be done before we import JAX etc.
+from numpyro import set_host_device_count, set_platform  # noqa
+
+set_platform(ARGS.device)  # noqa
+
+from jax import numpy as jnp  # noqa
+import numpy as np  # noqa
+import csiborgtools  # noqa
+from scipy.stats import multivariate_normal  # noqa
+
+
+def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
+    """Compute evidence using the `harmonic` package."""
+    data, names = csiborgtools.dict_samples_to_array(samples)
+    data = data.reshape(nchains_harmonic, -1, len(names))
+    log_posterior = log_posterior.reshape(nchains_harmonic, -1)
+
+    return csiborgtools.harmonic_evidence(
+        data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
+
+
+ndim = 250
+nsamples = 100_000
+nchains_split = 10
+loc = jnp.zeros(ndim)
+cov = jnp.eye(ndim)
+
+
+gen = np.random.default_rng()
+X = gen.multivariate_normal(loc, cov, size=nsamples)
+samples = {f"x_{i}": X[:, i] for i in range(ndim)}
+logprob = multivariate_normal(loc, cov).logpdf(X)
+
+neg_lnZ_laplace, neg_lnZ_laplace_error = csiborgtools.laplace_evidence(
+    samples, logprob, nchains_split)
+print(f"neg_lnZ_laplace: {neg_lnZ_laplace} +/- {neg_lnZ_laplace_error}")
+
+
+neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
+    samples, logprob, nchains_split, epoch_num=30)
+print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")
+
+
+
+
+
+
+
+
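
Note on the new selection-weighted magnitude likelihood in `PV_LogLikelihood`: with the toy selection function F(m), the per-object term added by this patch is ln F(m_obs) + ln N(m_obs | m_true, e_mag) minus the log of the normalisation integral of F(m) N(m | m_true, e_mag) over m_true +- 5 e_mag, evaluated with Simpson's rule on 100 points. The sketch below reproduces that branch outside the model; the function name `log_toy_selected_magnitude_likelihood` is ours, and it uses `jax.scipy.stats.norm.logpdf` in place of the package's internal `normal_logpdf` helper, which is assumed to be an ordinary Gaussian log-density.

from jax import numpy as jnp
from jax.scipy.stats import norm
from quadax import simpson

from csiborgtools.flow.selection import toy_log_magnitude_selection


def log_toy_selected_magnitude_likelihood(mag_obs, mag_true, e_mag, m1, m2, a):
    """Selection-weighted magnitude log-likelihood, normalised per object."""
    # Unnormalised term: ln F(m_obs) + ln N(m_obs | m_true, e_mag).
    ll = toy_log_magnitude_selection(mag_obs, m1, m2, a)
    ll += norm.logpdf(mag_obs, mag_true, e_mag)

    # Per-object normalisation: integral of F(m) N(m | m_true, e_mag) dm
    # over +-5 sigma, integrated with Simpson's rule, as in the patched model.
    m_xrange = jnp.linspace(mag_true - 5 * e_mag, mag_true + 5 * e_mag, 100).T
    lognorm = toy_log_magnitude_selection(m_xrange, m1, m2, a)
    lognorm += norm.logpdf(m_xrange, mag_true[:, None], e_mag[:, None])
    return ll - jnp.log(simpson(jnp.exp(lognorm), x=m_xrange, axis=-1))

Here `mag_obs`, `mag_true` and `e_mag` are 1-dimensional arrays of equal length, matching the per-object shapes inside the model.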
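The SFI_gals triple hard-coded in `get_toy_selection` ([1.221e+01, 1.297e+01, -2.708e-01]) has the form of a `ToyMagnitudeSelection.fit` result. A hedged sketch of producing such a fit is below; the synthetic magnitudes are a stand-in for the real catalogue values, which in the pipeline would come from the `DataLoader` catalogue (`loader.cat["mag"]`).

import numpy as np

import csiborgtools

# Synthetic stand-in for the observed SFI_gals apparent magnitudes; the size
# and rough range are illustrative assumptions only.
rng = np.random.default_rng(42)
mag = rng.normal(12.0, 1.0, size=2000)

selection = csiborgtools.flow.ToyMagnitudeSelection()
res = selection.fit(mag)  # Nelder-Mead minimisation of the negative log observed PDF
m1, m2, a = res.x
print(f"m1 = {m1:.3f}, m2 = {m2:.3f}, a = {a:.3f}")

The fitted `(m1, m2, a)` is what `get_model(loader, ..., toy_selection=(m1, m2, a))` forwards to `PV_LogLikelihood`, and what the `toy_selection` switch in `flow_validation.py` looks up per catalogue.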
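A note on the expected output of `scripts/flow/test_harmonic.py`: the samples are drawn from the same 250-dimensional unit Gaussian whose normalised log-density is supplied as `logprob`, so the target evidence is Z = 1 (the integral of a normalised density) and both the Laplace and `harmonic` estimates of -ln Z should scatter around 0 within their reported errors, which makes the script a convenient sanity check for the evidence machinery used in the flow validation.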