mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 12:18:01 +00:00
Add more about evidence and selection to flow (#142)
* Add Laplace evidence * Numerically stable laplace evidence * Minor edits to Laplace * Remove rmax * Rm old things * Rm comments * Add script * Add super toy selection * Add super toy selection * Update script
This commit is contained in:
parent
d13246a394
commit
3d1e1c0ae3
8 changed files with 243 additions and 57 deletions
|
@ -17,3 +17,4 @@ from .flow_model import (DataLoader, PV_LogLikelihood, PV_validation_model,
|
||||||
Observed2CosmologicalRedshift, predict_zobs, # noqa
|
Observed2CosmologicalRedshift, predict_zobs, # noqa
|
||||||
project_Vext, radial_velocity_los, # noqa
|
project_Vext, radial_velocity_los, # noqa
|
||||||
stack_pzosmo_over_realizations) # noqa
|
stack_pzosmo_over_realizations) # noqa
|
||||||
|
from .selection import ToyMagnitudeSelection # noqa
|
||||||
|
|
|
@ -35,6 +35,7 @@ from numpyro.distributions import Normal, Uniform, MultivariateNormal
|
||||||
from quadax import simpson
|
from quadax import simpson
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
|
from .selection import toy_log_magnitude_selection
|
||||||
from ..params import SPEED_OF_LIGHT, simname2Omega_m
|
from ..params import SPEED_OF_LIGHT, simname2Omega_m
|
||||||
from ..utils import fprint, radec_to_galactic, radec_to_supergalactic
|
from ..utils import fprint, radec_to_galactic, radec_to_supergalactic
|
||||||
|
|
||||||
|
@ -78,7 +79,7 @@ class DataLoader:
|
||||||
self._catname = catalogue
|
self._catname = catalogue
|
||||||
|
|
||||||
fprint("reading the interpolated field.", verbose)
|
fprint("reading the interpolated field.", verbose)
|
||||||
self._field_rdist, self._los_density, self._los_velocity, self._rmax = self._read_field( # noqa
|
self._field_rdist, self._los_density, self._los_velocity = self._read_field( # noqa
|
||||||
simname, ksim, catalogue, ksmooth, paths)
|
simname, ksim, catalogue, ksmooth, paths)
|
||||||
|
|
||||||
if len(self._cat) != self._los_density.shape[1]:
|
if len(self._cat) != self._los_density.shape[1]:
|
||||||
|
@ -169,14 +170,6 @@ class DataLoader:
|
||||||
|
|
||||||
return self._los_velocity[:, :, self._mask, ...]
|
return self._los_velocity[:, :, self._mask, ...]
|
||||||
|
|
||||||
@property
|
|
||||||
def rmax(self):
|
|
||||||
"""
|
|
||||||
Radial distance above which the underlying reconstruction is
|
|
||||||
extrapolated `(n_sims, n_objects)`.
|
|
||||||
"""
|
|
||||||
return self._rmax[:, self._mask]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def los_radial_velocity(self):
|
def los_radial_velocity(self):
|
||||||
"""
|
"""
|
||||||
|
@ -201,7 +194,6 @@ class DataLoader:
|
||||||
|
|
||||||
los_density = [None] * len(ksims)
|
los_density = [None] * len(ksims)
|
||||||
los_velocity = [None] * len(ksims)
|
los_velocity = [None] * len(ksims)
|
||||||
rmax = [None] * len(ksims)
|
|
||||||
|
|
||||||
for n, ksim in enumerate(ksims):
|
for n, ksim in enumerate(ksims):
|
||||||
nsim = nsims[ksim]
|
nsim = nsims[ksim]
|
||||||
|
@ -216,13 +208,11 @@ class DataLoader:
|
||||||
los_density[n] = f[f"density_{nsim}"][indx]
|
los_density[n] = f[f"density_{nsim}"][indx]
|
||||||
los_velocity[n] = f[f"velocity_{nsim}"][indx]
|
los_velocity[n] = f[f"velocity_{nsim}"][indx]
|
||||||
rdist = f[f"rdist_{nsim}"][...]
|
rdist = f[f"rdist_{nsim}"][...]
|
||||||
rmax[n] = f[f"rmax_{nsim}"][indx]
|
|
||||||
|
|
||||||
los_density = np.stack(los_density)
|
los_density = np.stack(los_density)
|
||||||
los_velocity = np.stack(los_velocity)
|
los_velocity = np.stack(los_velocity)
|
||||||
rmax = np.stack(rmax)
|
|
||||||
|
|
||||||
return rdist, los_density, los_velocity, rmax
|
return rdist, los_density, los_velocity
|
||||||
|
|
||||||
def _read_catalogue(self, catalogue, catalogue_fpath):
|
def _read_catalogue(self, catalogue, catalogue_fpath):
|
||||||
if catalogue == "A2":
|
if catalogue == "A2":
|
||||||
|
@ -622,9 +612,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
LOS density field.
|
LOS density field.
|
||||||
los_velocity : 3-dimensional array of shape (n_sims, n_objects, n_steps)
|
los_velocity : 3-dimensional array of shape (n_sims, n_objects, n_steps)
|
||||||
LOS radial velocity field.
|
LOS radial velocity field.
|
||||||
rmax : 1-dimensional array of shape (n_sims, n_objects)
|
|
||||||
Radial distance above which the underlying reconstruction is
|
|
||||||
extrapolated.
|
|
||||||
RA, dec : 1-dimensional arrays of shape (n_objects)
|
RA, dec : 1-dimensional arrays of shape (n_objects)
|
||||||
Right ascension and declination in degrees.
|
Right ascension and declination in degrees.
|
||||||
z_obs : 1-dimensional array of shape (n_objects)
|
z_obs : 1-dimensional array of shape (n_objects)
|
||||||
|
@ -643,11 +630,13 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
Catalogue kind, either "TFR", "SN", or "simple".
|
Catalogue kind, either "TFR", "SN", or "simple".
|
||||||
name : str
|
name : str
|
||||||
Name of the catalogue.
|
Name of the catalogue.
|
||||||
|
toy_selection : tuple of length 3, optional
|
||||||
|
Toy magnitude selection paramers `m1`, `m2` and `a`. Optional.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, los_density, los_velocity, rmax, RA, dec, z_obs,
|
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
||||||
e_zobs, calibration_params, maxmag_selection, r_xrange,
|
calibration_params, maxmag_selection, r_xrange, Omega_m,
|
||||||
Omega_m, kind, name):
|
kind, name, toy_selection=None):
|
||||||
if e_zobs is not None:
|
if e_zobs is not None:
|
||||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||||
else:
|
else:
|
||||||
|
@ -657,9 +646,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
RA = np.deg2rad(RA)
|
RA = np.deg2rad(RA)
|
||||||
dec = np.deg2rad(dec)
|
dec = np.deg2rad(dec)
|
||||||
|
|
||||||
names = ["los_density", "los_velocity", "rmax", "RA", "dec", "z_obs",
|
names = ["los_density", "los_velocity", "RA", "dec", "z_obs",
|
||||||
"e2_cz_obs"]
|
"e2_cz_obs"]
|
||||||
values = [los_density, los_velocity, rmax, RA, dec, z_obs, e2_cz_obs]
|
values = [los_density, los_velocity, RA, dec, z_obs, e2_cz_obs]
|
||||||
self._setattr_as_jax(names, values)
|
self._setattr_as_jax(names, values)
|
||||||
self._set_calibration_params(calibration_params)
|
self._set_calibration_params(calibration_params)
|
||||||
self._set_radial_spacing(r_xrange, Omega_m)
|
self._set_radial_spacing(r_xrange, Omega_m)
|
||||||
|
@ -669,6 +658,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
self.Omega_m = Omega_m
|
self.Omega_m = Omega_m
|
||||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||||
self.maxmag_selection = maxmag_selection
|
self.maxmag_selection = maxmag_selection
|
||||||
|
self.toy_selection = toy_selection
|
||||||
|
|
||||||
if kind == "TFR":
|
if kind == "TFR":
|
||||||
self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
|
self.mag_min, self.mag_max = jnp.min(self.mag), jnp.max(self.mag)
|
||||||
|
@ -688,6 +678,17 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
if maxmag_selection is not None and self.maxmag_selection > self.mag_max: # noqa
|
if maxmag_selection is not None and self.maxmag_selection > self.mag_max: # noqa
|
||||||
raise ValueError("The maximum magnitude cannot be larger than the selection threshold.") # noqa
|
raise ValueError("The maximum magnitude cannot be larger than the selection threshold.") # noqa
|
||||||
|
|
||||||
|
if toy_selection is not None and self.maxmag_selection is not None:
|
||||||
|
raise ValueError("`toy_selection` and `maxmag_selection` cannot be used together.") # noqa
|
||||||
|
|
||||||
|
if toy_selection is not None:
|
||||||
|
self.m1, self.m2, self.a = toy_selection
|
||||||
|
self.log_Fm = toy_log_magnitude_selection(
|
||||||
|
self.mag, self.m1, self.m2, self.a)
|
||||||
|
|
||||||
|
if toy_selection is not None and self.kind != "TFR":
|
||||||
|
raise ValueError("Toy selection is only implemented for TFRs.")
|
||||||
|
|
||||||
def __call__(self, field_calibration_params, distmod_params,
|
def __call__(self, field_calibration_params, distmod_params,
|
||||||
inference_method):
|
inference_method):
|
||||||
if inference_method not in ["mike", "bayes"]:
|
if inference_method not in ["mike", "bayes"]:
|
||||||
|
@ -772,12 +773,30 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
|
|
||||||
mag_true, eta_true = x_true[..., 0], x_true[..., 1]
|
mag_true, eta_true = x_true[..., 0], x_true[..., 1]
|
||||||
# Log-likelihood of the observed magnitudes.
|
# Log-likelihood of the observed magnitudes.
|
||||||
if self.maxmag_selection is None:
|
if self.maxmag_selection is not None:
|
||||||
ll0 += jnp.sum(normal_logpdf(
|
|
||||||
self.mag, mag_true, self.e_mag))
|
|
||||||
else:
|
|
||||||
ll0 += jnp.sum(upper_truncated_normal_logpdf(
|
ll0 += jnp.sum(upper_truncated_normal_logpdf(
|
||||||
self.mag, mag_true, self.e_mag, self.maxmag_selection))
|
self.mag, mag_true, self.e_mag, self.maxmag_selection))
|
||||||
|
elif self.toy_selection is not None:
|
||||||
|
ll_mag = self.log_Fm
|
||||||
|
ll_mag += normal_logpdf(self.mag, mag_true, self.e_mag)
|
||||||
|
|
||||||
|
# Normalization per datapoint, initially (ndata, nxrange)
|
||||||
|
mu_start = mag_true - 5 * self.e_mag
|
||||||
|
mu_end = mag_true + 5 * self.e_mag
|
||||||
|
# 100 is a reasonable and sufficient choice.
|
||||||
|
mu_xrange = jnp.linspace(mu_start, mu_end, 100).T
|
||||||
|
|
||||||
|
norm = toy_log_magnitude_selection(
|
||||||
|
mu_xrange, self.m1, self.m2, self.a)
|
||||||
|
norm = norm + normal_logpdf(
|
||||||
|
mu_xrange, mag_true[:, None], self.e_mag[:, None])
|
||||||
|
# Now integrate over the magnitude range.
|
||||||
|
norm = simpson(jnp.exp(norm), x=mu_xrange, axis=-1)
|
||||||
|
|
||||||
|
ll0 += jnp.sum(ll_mag - jnp.log(norm))
|
||||||
|
else:
|
||||||
|
ll0 += jnp.sum(normal_logpdf(
|
||||||
|
self.mag, mag_true, self.e_mag))
|
||||||
|
|
||||||
# Log-likelihood of the observed linewidths.
|
# Log-likelihood of the observed linewidths.
|
||||||
ll0 += jnp.sum(normal_logpdf(eta_true, self.eta, self.e_eta))
|
ll0 += jnp.sum(normal_logpdf(eta_true, self.eta, self.e_eta))
|
||||||
|
@ -876,7 +895,8 @@ def PV_validation_model(models, distmod_hyperparams_per_model,
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None,
|
||||||
|
toy_selection=None):
|
||||||
"""
|
"""
|
||||||
Get a model and extract the relevant data from the loader.
|
Get a model and extract the relevant data from the loader.
|
||||||
|
|
||||||
|
@ -890,6 +910,9 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
Maximum observed redshift in the CMB frame to include.
|
Maximum observed redshift in the CMB frame to include.
|
||||||
maxmag_selection : float, optional
|
maxmag_selection : float, optional
|
||||||
Maximum magnitude selection threshold.
|
Maximum magnitude selection threshold.
|
||||||
|
toy_selection : tuple of length 3, optional
|
||||||
|
Toy magnitude selection paramers `m1`, `m2` and `a` for TFRs of the
|
||||||
|
Boubel+24 model.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -899,7 +922,6 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
|
|
||||||
los_overdensity = loader.los_density
|
los_overdensity = loader.los_density
|
||||||
los_velocity = loader.los_radial_velocity
|
los_velocity = loader.los_radial_velocity
|
||||||
rmax = loader.rmax
|
|
||||||
kind = loader._catname
|
kind = loader._catname
|
||||||
|
|
||||||
if maxmag_selection is not None and kind != "2MTF":
|
if maxmag_selection is not None and kind != "2MTF":
|
||||||
|
@ -917,7 +939,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
"e_c": e_c[mask]}
|
"e_c": e_c[mask]}
|
||||||
|
|
||||||
model = PV_LogLikelihood(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
|
RA[mask], dec[mask], zCMB[mask], e_zCMB, calibration_params,
|
||||||
maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
|
maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||||
elif "Pantheon+" in kind:
|
elif "Pantheon+" in kind:
|
||||||
|
@ -945,20 +967,27 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
"e_mag": e_mB[mask], "e_x1": e_x1[mask],
|
"e_mag": e_mB[mask], "e_x1": e_x1[mask],
|
||||||
"e_c": e_c[mask]}
|
"e_c": e_c[mask]}
|
||||||
model = PV_LogLikelihood(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
|
RA[mask], dec[mask], zCMB[mask], e_zCMB[mask], calibration_params,
|
||||||
maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
|
maxmag_selection, loader.rdist, loader._Omega_m, "SN", name=kind)
|
||||||
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
|
elif kind in ["SFI_gals", "2MTF", "SFI_gals_masked"]:
|
||||||
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
|
keys = ["RA", "DEC", "z_CMB", "mag", "eta", "e_mag", "e_eta"]
|
||||||
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
|
RA, dec, zCMB, mag, eta, e_mag, e_eta = (loader.cat[k] for k in keys)
|
||||||
|
|
||||||
|
if kind == "SFI_gals" and toy_selection is not None:
|
||||||
|
if len(toy_selection) != 3:
|
||||||
|
raise ValueError("Toy selection must be a tuple with 3 elements.") # noqa
|
||||||
|
m1, m2, a = toy_selection
|
||||||
|
fprint(f"using toy selection with m1 = {m1}, m2 = {m2}, a = {a}.")
|
||||||
|
|
||||||
mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)
|
mask = (zCMB < zcmb_max) & (zCMB > zcmb_min)
|
||||||
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
||||||
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
||||||
model = PV_LogLikelihood(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
||||||
maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
|
maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind,
|
||||||
|
toy_selection=toy_selection)
|
||||||
elif "CF4_TFR_" in kind:
|
elif "CF4_TFR_" in kind:
|
||||||
# The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i".
|
# The full name can be e.g. "CF4_TFR_not2MTForSFI_i" or "CF4_TFR_i".
|
||||||
band = kind.split("_")[-1]
|
band = kind.split("_")[-1]
|
||||||
|
@ -995,7 +1024,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
calibration_params = {"mag": mag[mask], "eta": eta[mask],
|
||||||
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
"e_mag": e_mag[mask], "e_eta": e_eta[mask]}
|
||||||
model = PV_LogLikelihood(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask],
|
||||||
RA[mask], dec[mask], z_obs[mask], None, calibration_params,
|
RA[mask], dec[mask], z_obs[mask], None, calibration_params,
|
||||||
maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
|
maxmag_selection, loader.rdist, loader._Omega_m, "TFR", name=kind)
|
||||||
elif kind in ["CF4_GroupAll"]:
|
elif kind in ["CF4_GroupAll"]:
|
||||||
|
@ -1011,7 +1040,7 @@ def get_model(loader, zcmb_min=0.0, zcmb_max=None, maxmag_selection=None):
|
||||||
|
|
||||||
calibration_params = {"mu": mu[mask], "e_mu": e_mu[mask]}
|
calibration_params = {"mu": mu[mask], "e_mu": e_mu[mask]}
|
||||||
model = PV_LogLikelihood(
|
model = PV_LogLikelihood(
|
||||||
los_overdensity[:, mask], los_velocity[:, mask], rmax[:, mask],
|
los_overdensity[:, mask], los_velocity[:, mask],
|
||||||
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
RA[mask], dec[mask], zCMB[mask], None, calibration_params,
|
||||||
maxmag_selection, loader.rdist, loader._Omega_m, "simple",
|
maxmag_selection, loader.rdist, loader._Omega_m, "simple",
|
||||||
name=kind)
|
name=kind)
|
||||||
|
|
69
csiborgtools/flow/selection.py
Normal file
69
csiborgtools/flow/selection.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright (C) 2024 Richard Stiskalek
|
||||||
|
# This program is free software; you can redistribute it and/or modify it
|
||||||
|
# under the terms of the GNU General Public License as published by the
|
||||||
|
# Free Software Foundation; either version 3 of the License, or (at your
|
||||||
|
# option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful, but
|
||||||
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
|
||||||
|
# Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License along
|
||||||
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
|
"""Selection functions for peculiar velocities."""
|
||||||
|
import numpy as np
|
||||||
|
from jax import numpy as jnp
|
||||||
|
from scipy.integrate import quad
|
||||||
|
from scipy.optimize import minimize
|
||||||
|
|
||||||
|
|
||||||
|
class ToyMagnitudeSelection:
|
||||||
|
"""
|
||||||
|
Toy magnitude selection according to Boubel et al 2024.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def log_true_pdf(self, m, m1):
|
||||||
|
"""Unnormalized `true' PDF."""
|
||||||
|
return 0.6 * (m - m1)
|
||||||
|
|
||||||
|
def log_selection_function(self, m, m1, m2, a):
|
||||||
|
return np.where(m <= m1,
|
||||||
|
0,
|
||||||
|
a * (m - m2)**2 - a * (m1 - m2)**2 - 0.6 * (m - m1))
|
||||||
|
|
||||||
|
def log_observed_pdf(self, m, m1, m2, a):
|
||||||
|
# Calculate the normalization constant
|
||||||
|
f = lambda m: 10**(self.log_true_pdf(m, m1) # noqa
|
||||||
|
+ self.log_selection_function(m, m1, m2, a))
|
||||||
|
mmin, mmax = 0, 25
|
||||||
|
norm = quad(f, mmin, mmax)[0]
|
||||||
|
|
||||||
|
return (self.log_true_pdf(m, m1)
|
||||||
|
+ self.log_selection_function(m, m1, m2, a)
|
||||||
|
- np.log10(norm))
|
||||||
|
|
||||||
|
def fit(self, mag):
|
||||||
|
|
||||||
|
def loss(x):
|
||||||
|
m1, m2, a = x
|
||||||
|
|
||||||
|
if a >= 0:
|
||||||
|
return np.inf
|
||||||
|
|
||||||
|
return -np.sum(self.log_observed_pdf(mag, m1, m2, a))
|
||||||
|
|
||||||
|
x0 = [12.0, 12.5, -0.1]
|
||||||
|
return minimize(loss, x0, method="Nelder-Mead")
|
||||||
|
|
||||||
|
|
||||||
|
def toy_log_magnitude_selection(mag, m1, m2, a):
|
||||||
|
"""JAX implementation of `ToyMagnitudeSelection` but natural logarithm."""
|
||||||
|
return jnp.log(10) * jnp.where(
|
||||||
|
mag <= m1,
|
||||||
|
0,
|
||||||
|
a * (mag - m2)**2 - a * (m1 - m2)**2 - 0.6 * (mag - m1))
|
|
@ -492,6 +492,11 @@ def dict_samples_to_array(samples):
|
||||||
for i in range(value.shape[-1]):
|
for i in range(value.shape[-1]):
|
||||||
data.append(value[:, i])
|
data.append(value[:, i])
|
||||||
names.append(f"{key}_{i}")
|
names.append(f"{key}_{i}")
|
||||||
|
elif value.ndim == 3:
|
||||||
|
for i in range(value.shape[-1]):
|
||||||
|
for j in range(value.shape[-2]):
|
||||||
|
data.append(value[:, j, i])
|
||||||
|
names.append(f"{key}_{j}_{i}")
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid dimensionality of samples to stack.")
|
raise ValueError("Invalid dimensionality of samples to stack.")
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ MAS="SPH"
|
||||||
grid=1024
|
grid=1024
|
||||||
|
|
||||||
|
|
||||||
for simname in "Carrick2015"; do
|
for simname in "Lilow2024"; do
|
||||||
for catalogue in "CF4_TFR"; do
|
for catalogue in "CF4_TFR"; do
|
||||||
pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid"
|
pythoncm="$env $file --catalogue $catalogue --nsims $nsims --simname $simname --MAS $MAS --grid $grid"
|
||||||
if [ $on_login -eq 1 ]; then
|
if [ $on_login -eq 1 ]; then
|
||||||
|
|
|
@ -72,7 +72,7 @@ def print_variables(names, variables):
|
||||||
print(flush=True)
|
print(flush=True)
|
||||||
|
|
||||||
|
|
||||||
def get_models(get_model_kwargs, verbose=True):
|
def get_models(get_model_kwargs, toy_selection, verbose=True):
|
||||||
"""Load the data and create the NumPyro models."""
|
"""Load the data and create the NumPyro models."""
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||||
|
@ -110,7 +110,8 @@ def get_models(get_model_kwargs, verbose=True):
|
||||||
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
loader = csiborgtools.flow.DataLoader(ARGS.simname, nsim_iterator,
|
||||||
cat, fpath, paths,
|
cat, fpath, paths,
|
||||||
ksmooth=ARGS.ksmooth)
|
ksmooth=ARGS.ksmooth)
|
||||||
models[i] = csiborgtools.flow.get_model(loader, **get_model_kwargs)
|
models[i] = csiborgtools.flow.get_model(
|
||||||
|
loader, toy_selection=toy_selection[i], **get_model_kwargs)
|
||||||
|
|
||||||
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
print(f"\n{'Num. radial steps':<20} {len(loader.rdist)}\n", flush=True)
|
||||||
return models
|
return models
|
||||||
|
@ -127,7 +128,7 @@ def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
||||||
|
|
||||||
|
|
||||||
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
calculate_evidence, nchains_harmonic, epoch_num, kwargs_print):
|
calculate_harmonic, nchains_harmonic, epoch_num, kwargs_print):
|
||||||
"""Run the NumPyro model and save output to a file."""
|
"""Run the NumPyro model and save output to a file."""
|
||||||
try:
|
try:
|
||||||
ndata = sum(model.ndata for model in model_kwargs["models"])
|
ndata = sum(model.ndata for model in model_kwargs["models"])
|
||||||
|
@ -148,12 +149,12 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
print(f"{'AIC':<20} {AIC}")
|
print(f"{'AIC':<20} {AIC}")
|
||||||
mcmc.print_summary()
|
mcmc.print_summary()
|
||||||
|
|
||||||
if calculate_evidence:
|
if calculate_harmonic:
|
||||||
print("Calculating the evidence using `harmonic`.", flush=True)
|
print("Calculating the evidence using `harmonic`.", flush=True)
|
||||||
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
|
neg_ln_evidence, neg_ln_evidence_err = get_harmonic_evidence(
|
||||||
samples, log_posterior, nchains_harmonic, epoch_num)
|
samples, log_posterior, nchains_harmonic, epoch_num)
|
||||||
print(f"{'-ln(Z)':<20} {neg_ln_evidence}")
|
print(f"{'-ln(Z_h)':<20} {neg_ln_evidence}")
|
||||||
print(f"{'-ln(Z) error':<20} {neg_ln_evidence_err}")
|
print(f"{'-ln(Z_h) error':<20} {neg_ln_evidence_err}")
|
||||||
else:
|
else:
|
||||||
neg_ln_evidence = jax.numpy.nan
|
neg_ln_evidence = jax.numpy.nan
|
||||||
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
neg_ln_evidence_err = (jax.numpy.nan, jax.numpy.nan)
|
||||||
|
@ -180,8 +181,8 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
grp = f.create_group("gof")
|
grp = f.create_group("gof")
|
||||||
grp.create_dataset("BIC", data=BIC)
|
grp.create_dataset("BIC", data=BIC)
|
||||||
grp.create_dataset("AIC", data=AIC)
|
grp.create_dataset("AIC", data=AIC)
|
||||||
grp.create_dataset("neg_lnZ", data=neg_ln_evidence)
|
grp.create_dataset("neg_lnZ_harmonic", data=neg_ln_evidence)
|
||||||
grp.create_dataset("neg_lnZ_err", data=neg_ln_evidence_err)
|
grp.create_dataset("neg_lnZ_harmonic_err", data=neg_ln_evidence_err)
|
||||||
|
|
||||||
fname_summary = fname.replace(".hdf5", ".txt")
|
fname_summary = fname.replace(".hdf5", ".txt")
|
||||||
print(f"Saving summary to `{fname_summary}`.")
|
print(f"Saving summary to `{fname_summary}`.")
|
||||||
|
@ -206,7 +207,7 @@ def run_model(model, nsteps, nburn, model_kwargs, out_folder, sample_beta,
|
||||||
# Command line interface #
|
# Command line interface #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def get_distmod_hyperparams(catalogue, sample_alpha):
|
def get_distmod_hyperparams(catalogue, sample_alpha, sample_mag_dipole):
|
||||||
alpha_min = -1.0
|
alpha_min = -1.0
|
||||||
alpha_max = 3.0
|
alpha_max = 3.0
|
||||||
|
|
||||||
|
@ -225,7 +226,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
||||||
"c_mean": 0., "c_std": 20.0,
|
"c_mean": 0., "c_std": 20.0,
|
||||||
"sample_curvature": False,
|
"sample_curvature": False,
|
||||||
"a_dipole_mean": 0., "a_dipole_std": 1.0,
|
"a_dipole_mean": 0., "a_dipole_std": 1.0,
|
||||||
"sample_a_dipole": True,
|
"sample_a_dipole": sample_mag_dipole,
|
||||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
"sample_alpha": sample_alpha,
|
"sample_alpha": sample_alpha,
|
||||||
}
|
}
|
||||||
|
@ -233,7 +234,7 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
||||||
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
return {"e_mu_min": 0.001, "e_mu_max": 1.0,
|
||||||
"dmu_min": -3.0, "dmu_max": 3.0,
|
"dmu_min": -3.0, "dmu_max": 3.0,
|
||||||
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
|
"dmu_dipole_mean": 0., "dmu_dipole_std": 1.0,
|
||||||
"sample_dmu_dipole": True,
|
"sample_dmu_dipole": sample_mag_dipole,
|
||||||
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
"alpha_min": alpha_min, "alpha_max": alpha_max,
|
||||||
"sample_alpha": sample_alpha,
|
"sample_alpha": sample_alpha,
|
||||||
}
|
}
|
||||||
|
@ -241,6 +242,16 @@ def get_distmod_hyperparams(catalogue, sample_alpha):
|
||||||
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_toy_selection(toy_selection, catalogue):
|
||||||
|
if not toy_selection:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if catalogue == "SFI_gals":
|
||||||
|
return [1.221e+01, 1.297e+01, -2.708e-01]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported catalogue: `{ARGS.catalogue}`.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity" # noqa
|
out_folder = "/mnt/extraspace/rstiskalek/csiborg_postprocessing/peculiar_velocity" # noqa
|
||||||
|
@ -251,18 +262,23 @@ if __name__ == "__main__":
|
||||||
# Fixed user parameters #
|
# Fixed user parameters #
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
||||||
nsteps = 1500
|
nsteps = 1000
|
||||||
nburn = 1000
|
nburn = 500
|
||||||
zcmb_min = 0
|
zcmb_min = 0
|
||||||
zcmb_max = 0.05
|
zcmb_max = 0.05
|
||||||
calculate_evidence = False
|
|
||||||
nchains_harmonic = 10
|
nchains_harmonic = 10
|
||||||
num_epochs = 30
|
num_epochs = 50
|
||||||
inference_method = "mike"
|
inference_method = "bayes"
|
||||||
|
calculate_harmonic = True if inference_method == "mike" else False
|
||||||
maxmag_selection = None
|
maxmag_selection = None
|
||||||
sample_alpha = True
|
sample_alpha = False
|
||||||
sample_beta = True
|
sample_beta = True
|
||||||
sample_Vmono = False
|
sample_Vmono = False
|
||||||
|
sample_mag_dipole = False
|
||||||
|
toy_selection = True
|
||||||
|
|
||||||
|
if toy_selection and inference_method == "mike":
|
||||||
|
raise ValueError("Toy selection is not supported with `mike` inference.") # noqa
|
||||||
|
|
||||||
if nsteps % nchains_harmonic != 0:
|
if nsteps % nchains_harmonic != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -272,10 +288,12 @@ if __name__ == "__main__":
|
||||||
"zcmb_min": zcmb_min,
|
"zcmb_min": zcmb_min,
|
||||||
"zcmb_max": zcmb_max,
|
"zcmb_max": zcmb_max,
|
||||||
"maxmag_selection": maxmag_selection,
|
"maxmag_selection": maxmag_selection,
|
||||||
"calculate_evidence": calculate_evidence,
|
"calculate_harmonic": calculate_harmonic,
|
||||||
"nchains_harmonic": nchains_harmonic,
|
"nchains_harmonic": nchains_harmonic,
|
||||||
"num_epochs": num_epochs,
|
"num_epochs": num_epochs,
|
||||||
"inference_method": inference_method}
|
"inference_method": inference_method,
|
||||||
|
"sample_mag_dipole": sample_mag_dipole,
|
||||||
|
"toy_selection": toy_selection}
|
||||||
print_variables(main_params.keys(), main_params.values())
|
print_variables(main_params.keys(), main_params.values())
|
||||||
|
|
||||||
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
calibration_hyperparams = {"Vext_min": -1000, "Vext_max": 1000,
|
||||||
|
@ -290,7 +308,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
distmod_hyperparams_per_catalogue = []
|
distmod_hyperparams_per_catalogue = []
|
||||||
for cat in ARGS.catalogue:
|
for cat in ARGS.catalogue:
|
||||||
x = get_distmod_hyperparams(cat, sample_alpha)
|
x = get_distmod_hyperparams(cat, sample_alpha, sample_mag_dipole)
|
||||||
print(f"\n{cat} hyperparameters:")
|
print(f"\n{cat} hyperparameters:")
|
||||||
print_variables(x.keys(), x.values())
|
print_variables(x.keys(), x.values())
|
||||||
distmod_hyperparams_per_catalogue.append(x)
|
distmod_hyperparams_per_catalogue.append(x)
|
||||||
|
@ -301,7 +319,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
|
get_model_kwargs = {"zcmb_min": zcmb_min, "zcmb_max": zcmb_max,
|
||||||
"maxmag_selection": maxmag_selection}
|
"maxmag_selection": maxmag_selection}
|
||||||
models = get_models(get_model_kwargs, )
|
|
||||||
|
toy_selection = [get_toy_selection(toy_selection, cat)
|
||||||
|
for cat in ARGS.catalogue]
|
||||||
|
|
||||||
|
models = get_models(get_model_kwargs, toy_selection)
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"models": models,
|
"models": models,
|
||||||
"field_calibration_hyperparams": calibration_hyperparams,
|
"field_calibration_hyperparams": calibration_hyperparams,
|
||||||
|
@ -312,5 +334,5 @@ if __name__ == "__main__":
|
||||||
model = csiborgtools.flow.PV_validation_model
|
model = csiborgtools.flow.PV_validation_model
|
||||||
|
|
||||||
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
run_model(model, nsteps, nburn, model_kwargs, out_folder,
|
||||||
calibration_hyperparams["sample_beta"], calculate_evidence,
|
calibration_hyperparams["sample_beta"], calculate_harmonic,
|
||||||
nchains_harmonic, num_epochs, kwargs_print)
|
nchains_harmonic, num_epochs, kwargs_print)
|
||||||
|
|
|
@ -39,7 +39,7 @@ fi
|
||||||
|
|
||||||
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
|
# for simname in "Lilow2024" "CF4" "CF4gp" "csiborg1" "csiborg2_main" "csiborg2X"; do
|
||||||
for simname in "Carrick2015"; do
|
for simname in "Carrick2015"; do
|
||||||
for catalogue in "CF4_GroupAll"; do
|
for catalogue in "SFI_gals"; do
|
||||||
# for catalogue in "CF4_TFR_i"; do
|
# for catalogue in "CF4_TFR_i"; do
|
||||||
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
|
# for ksim in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20; do
|
||||||
for ksim in "none"; do
|
for ksim in "none"; do
|
||||||
|
|
60
scripts/flow/test_harmonic.py
Normal file
60
scripts/flow/test_harmonic.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
from argparse import ArgumentParser, ArgumentTypeError
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--device", type=str, default="cpu",
|
||||||
|
help="Device to use.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
ARGS = parse_args()
|
||||||
|
# This must be done before we import JAX etc.
|
||||||
|
from numpyro import set_host_device_count, set_platform # noqa
|
||||||
|
|
||||||
|
set_platform(ARGS.device) # noqa
|
||||||
|
|
||||||
|
from jax import numpy as jnp # noqa
|
||||||
|
import numpy as np # noqa
|
||||||
|
import csiborgtools # noqa
|
||||||
|
from scipy.stats import multivariate_normal # noqa
|
||||||
|
|
||||||
|
|
||||||
|
def get_harmonic_evidence(samples, log_posterior, nchains_harmonic, epoch_num):
|
||||||
|
"""Compute evidence using the `harmonic` package."""
|
||||||
|
data, names = csiborgtools.dict_samples_to_array(samples)
|
||||||
|
data = data.reshape(nchains_harmonic, -1, len(names))
|
||||||
|
log_posterior = log_posterior.reshape(nchains_harmonic, -1)
|
||||||
|
|
||||||
|
return csiborgtools.harmonic_evidence(
|
||||||
|
data, log_posterior, return_flow_samples=False, epochs_num=epoch_num)
|
||||||
|
|
||||||
|
|
||||||
|
ndim = 250
|
||||||
|
nsamples = 100_000
|
||||||
|
nchains_split = 10
|
||||||
|
loc = jnp.zeros(ndim)
|
||||||
|
cov = jnp.eye(ndim)
|
||||||
|
|
||||||
|
|
||||||
|
gen = np.random.default_rng()
|
||||||
|
X = gen.multivariate_normal(loc, cov, size=nsamples)
|
||||||
|
samples = {f"x_{i}": X[:, i] for i in range(ndim)}
|
||||||
|
logprob = multivariate_normal(loc, cov).logpdf(X)
|
||||||
|
|
||||||
|
neg_lnZ_laplace, neg_lnZ_laplace_error = csiborgtools.laplace_evidence(
|
||||||
|
samples, logprob, nchains_split)
|
||||||
|
print(f"neg_lnZ_laplace: {neg_lnZ_laplace} +/- {neg_lnZ_laplace_error}")
|
||||||
|
|
||||||
|
|
||||||
|
neg_lnZ_harmonic, neg_lnZ_harmonic_error = get_harmonic_evidence(
|
||||||
|
samples, logprob, nchains_split, epoch_num=30)
|
||||||
|
print(f"neg_lnZ_harmonic: {neg_lnZ_harmonic} +/- {neg_lnZ_harmonic_error}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue