Add p(z_cosmo | z_obs) (#122)

* Add import

* Add draft of the p(zcosmo)

* Add option for additional uncertainty

* Variable renaming

* Add zcosmo predict for flow calibration

* Add flow map notebook

* Update notebook

* Add posterior mean & std

* Edit docstring
This commit is contained in:
Richard Stiskalek 2024-04-01 15:44:20 +02:00 committed by GitHub
parent e3a645c935
commit 8e9645e202
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 584 additions and 8 deletions

View file

@ -17,4 +17,5 @@ from .flow_model import (DataLoader, radial_velocity_los, dist2redshift,
SD_PV_validation_model, SN_PV_validation_model, # noqa SD_PV_validation_model, SN_PV_validation_model, # noqa
TF_PV_validation_model, radec_to_galactic, # noqa TF_PV_validation_model, radec_to_galactic, # noqa
sample_prior, make_loss, get_model, # noqa sample_prior, make_loss, get_model, # noqa
optimize_model_with_jackknife, distmodulus2dist) # noqa optimize_model_with_jackknife, distmodulus2dist, # noqa
Observed2CosmologicalRedshift) # noqa

View file

@ -45,6 +45,7 @@ from tqdm import trange
from ..params import simname2Omega_m from ..params import simname2Omega_m
SPEED_OF_LIGHT = 299792.458 # km / s SPEED_OF_LIGHT = 299792.458 # km / s
H0 = 100 # km / s / Mpc
def t(): def t():
@ -143,7 +144,7 @@ class DataLoader:
# Normalize the CSiBORG density by the mean matter density # Normalize the CSiBORG density by the mean matter density
if "csiborg" in simname: if "csiborg" in simname:
cosmo = FlatLambdaCDM(H0=100, Om0=self._Omega_m) cosmo = FlatLambdaCDM(H0=H0, Om0=self._Omega_m)
mean_rho_matter = cosmo.critical_density0.to("Msun/kpc^3").value mean_rho_matter = cosmo.critical_density0.to("Msun/kpc^3").value
mean_rho_matter *= self._Omega_m mean_rho_matter *= self._Omega_m
self._los_density /= mean_rho_matter self._los_density /= mean_rho_matter
@ -402,11 +403,50 @@ def dist2redshift(dist, Omega_m):
------- -------
float or 1-dimensional array float or 1-dimensional array
""" """
H0 = 100
eta = 3 * Omega_m / 2 eta = 3 * Omega_m / 2
return 1 / eta * (1 - (1 - 2 * H0 * dist / SPEED_OF_LIGHT * eta)**0.5) return 1 / eta * (1 - (1 - 2 * H0 * dist / SPEED_OF_LIGHT * eta)**0.5)
def redshift2dist(z, Omega_m):
    """
    Second-order, low-redshift (z << 1) approximation of the comoving
    distance corresponding to a cosmological redshift in a flat Universe.

    Parameters
    ----------
    z : float or 1-dimensional array
        Cosmological redshift.
    Omega_m : float
        Matter density parameter.

    Returns
    -------
    float or 1-dimensional array
        Comoving distance in units of `SPEED_OF_LIGHT / H0`.
    """
    # Deceleration parameter expressed through the matter density.
    q0 = 3 * Omega_m / 2 - 1
    # Leading (Hubble-law) factor and the quadratic correction, kept as
    # separate terms for readability.
    hubble_term = SPEED_OF_LIGHT * z / (2 * H0)
    quadratic_term = 2 - z * (1 + q0)
    return hubble_term * quadratic_term
def gradient_redshift2dist(z, Omega_m):
    """
    Derivative with respect to redshift of the low-redshift (z << 1)
    redshift-to-comoving-distance conversion in a flat Universe, i.e. the
    derivative of `redshift2dist`.

    Parameters
    ----------
    z : float or 1-dimensional array
        Cosmological redshift.
    Omega_m : float
        Matter density parameter.

    Returns
    -------
    float or 1-dimensional array
        Derivative of the comoving distance in units of
        `SPEED_OF_LIGHT / H0`.
    """
    # Deceleration parameter expressed through the matter density.
    q0 = 3 * Omega_m / 2 - 1
    hubble_distance = SPEED_OF_LIGHT / H0
    return hubble_distance * (1 - z * (1 + q0))
def dist2distmodulus(dist, Omega_m): def dist2distmodulus(dist, Omega_m):
""" """
Convert comoving distance to distance modulus, assuming z << 1. Convert comoving distance to distance modulus, assuming z << 1.
@ -457,7 +497,7 @@ def distmodulus2dist(mu, Omega_m, ninterp=10000, zmax=0.1, mu2comoving=None,
""" """
if mu2comoving is None: if mu2comoving is None:
zrange = np.linspace(1e-15, zmax, ninterp) zrange = np.linspace(1e-15, zmax, ninterp)
cosmo = FlatLambdaCDM(H0=100, Om0=Omega_m) cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)
mu2comoving = interp1d( mu2comoving = interp1d(
cosmo.distmod(zrange).value, cosmo.comoving_distance(zrange).value, cosmo.distmod(zrange).value, cosmo.comoving_distance(zrange).value,
kind="cubic") kind="cubic")
@ -468,6 +508,45 @@ def distmodulus2dist(mu, Omega_m, ninterp=10000, zmax=0.1, mu2comoving=None,
return mu2comoving(mu) return mu2comoving(mu)
def distmodulus2redsfhit(mu, Omega_m, ninterp=10000, zmax=0.1, mu2z=None,
                         return_interpolator=False):
    """
    Map a distance modulus to the cosmological redshift. Unless an
    interpolator is passed in through `mu2z`, a new one is built on every
    call, which is costly.

    Parameters
    ----------
    mu : float or 1-dimensional array
        Distance modulus.
    Omega_m : float
        Matter density parameter.
    ninterp : int, optional
        Number of redshift grid points used to build the interpolator from
        distance modulus to cosmological redshift.
    zmax : float, optional
        Maximum redshift of the interpolation grid.
    mu2z : callable, optional
        Interpolator from distance modulus to cosmological redshift. If not
        provided, it is built every time the function is called.
    return_interpolator : bool, optional
        Whether to return the interpolator as well.

    Returns
    -------
    float (or 1-dimensional array) and callable (optional)
    """
    if mu2z is None:
        # Tabulate the distance modulus on a redshift grid and invert the
        # relation with a cubic interpolator.
        redshift_grid = np.linspace(1e-15, zmax, ninterp)
        cosmology = FlatLambdaCDM(H0=H0, Om0=Omega_m)
        mu_grid = cosmology.distmod(redshift_grid).value
        mu2z = interp1d(mu_grid, redshift_grid, kind="cubic")

    zcosmo = mu2z(mu)
    if not return_interpolator:
        return zcosmo

    return zcosmo, mu2z
def project_Vext(Vext_x, Vext_y, Vext_z, RA, dec): def project_Vext(Vext_x, Vext_y, Vext_z, RA, dec):
""" """
Project the external velocity onto the line of sight along direction Project the external velocity onto the line of sight along direction
@ -556,7 +635,7 @@ def calculate_ptilde_wo_bias(xrange, mu, err, r_squared_xrange=None,
return ptilde return ptilde
def calculate_ll_zobs(zobs, zobs_pred, sigma_v): def calculate_likelihood_zobs(zobs, zobs_pred, sigma_v):
""" """
Calculate the likelihood of the observed redshift given the predicted Calculate the likelihood of the observed redshift given the predicted
redshift. redshift.
@ -688,6 +767,10 @@ class BaseFlowValidationModel(ABC):
return mu, std return mu, std
@abstractmethod
def predict_zcosmo_from_calibration(self, **kwargs):
    """
    Predict the cosmological redshift from calibration parameters. Declared
    abstract, so concrete models are expected to override it; this base
    implementation does nothing and returns `None`.
    """
    return None
@abstractmethod @abstractmethod
def __call__(self, **kwargs): def __call__(self, **kwargs):
pass pass
@ -768,6 +851,9 @@ class SD_PV_validation_model(BaseFlowValidationModel):
def predict_zobs_single(self, **kwargs): def predict_zobs_single(self, **kwargs):
raise NotImplementedError("This method is not implemented yet.") raise NotImplementedError("This method is not implemented yet.")
def predict_zcosmo_from_calibration(self, **kwargs):
    """
    Placeholder for predicting the cosmological redshift from calibration
    parameters. Not available for this model: always raises
    `NotImplementedError`, regardless of the keyword arguments.
    """
    message = "This method is not implemented yet."
    raise NotImplementedError(message)
def __call__(self, sample_alpha=True, sample_beta=True): def __call__(self, sample_alpha=True, sample_beta=True):
""" """
The simple distance NumPyro PV validation model. The simple distance NumPyro PV validation model.
@ -804,7 +890,8 @@ class SD_PV_validation_model(BaseFlowValidationModel):
# Calculate p(z_obs) and multiply it by p(r) # Calculate p(z_obs) and multiply it by p(r)
zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i]) zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i])
ptilde *= calculate_ll_zobs(self._z_obs[i], zobs_pred, sigma_v) ptilde *= calculate_likelihood_zobs(
self._z_obs[i], zobs_pred, sigma_v)
return ll + jnp.log(self._f_simps(ptilde) / pnorm), None return ll + jnp.log(self._f_simps(ptilde) / pnorm), None
@ -924,6 +1011,43 @@ class SN_PV_validation_model(BaseFlowValidationModel):
return (self._e2_mB + alpha_cal**2 * self._e2_x1 return (self._e2_mB + alpha_cal**2 * self._e2_x1
+ beta_cal**2 * self._e2_c + e_mu_intrinsic**2) + beta_cal**2 * self._e2_c + e_mu_intrinsic**2)
def predict_zcosmo_from_calibration(self, mag_cal, alpha_cal, beta_cal,
                                    to_jit=True):
    """
    Predict the cosmological redshift of every object from samples of the
    SALT2 calibration parameters, summarised over the samples.

    Parameters
    ----------
    mag_cal, alpha_cal, beta_cal : 1-dimensional arrays
        Samples of the SALT2 calibration parameters. All three must be 1D
        and of equal shape.
    to_jit : bool, optional
        Whether to JIT compile the distance modulus function.

    Returns
    -------
    zcosmo_mean : 1-dimensional array
        Mean of the predicted redshifts over the calibration samples.
    zcosmo_std : 1-dimensional array
        Standard deviation of the predicted redshifts over the calibration
        samples.

    Raises
    ------
    ValueError
        If the calibration parameters are not 1D arrays of equal shape.
    """
    shapes_match = mag_cal.shape == alpha_cal.shape == beta_cal.shape
    if not (shapes_match and mag_cal.ndim == 1):
        raise ValueError("The shape of calibration parameters must be 1D and equal.")  # noqa

    distmod_fn = jit(self.mu) if to_jit else self.mu

    nsamples = len(mag_cal)
    # One row per calibration sample, one column per object.
    zcosmo = np.empty((nsamples, self.ndata), dtype=np.float32)

    # The interpolator is built on the first iteration and reused afterwards.
    interpolator = None
    for k in trange(nsamples):
        mu_k = distmod_fn(mag_cal[k], alpha_cal[k], beta_cal[k])
        zcosmo[k], interpolator = distmodulus2redsfhit(
            mu_k, self._Omega_m, mu2z=interpolator,
            return_interpolator=True)

    return zcosmo.mean(axis=0), zcosmo.std(axis=0)
def predict_zobs_single(self, Vext_x, Vext_y, Vext_z, alpha, beta, def predict_zobs_single(self, Vext_x, Vext_y, Vext_z, alpha, beta,
e_mu_intrinsic, mag_cal, alpha_cal, beta_cal, e_mu_intrinsic, mag_cal, alpha_cal, beta_cal,
**kwargs): **kwargs):
@ -1011,7 +1135,8 @@ class SN_PV_validation_model(BaseFlowValidationModel):
# Calculate p(z_obs) and multiply it by p(r) # Calculate p(z_obs) and multiply it by p(r)
zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i]) zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i])
ptilde *= calculate_ll_zobs(self._z_obs[i], zobs_pred, sigma_v) ptilde *= calculate_likelihood_zobs(
self._z_obs[i], zobs_pred, sigma_v)
return ll + jnp.log(self._f_simps(ptilde) / pnorm), None return ll + jnp.log(self._f_simps(ptilde) / pnorm), None
@ -1128,6 +1253,9 @@ class TF_PV_validation_model(BaseFlowValidationModel):
""" """
return (self._e2_mag + b**2 * self._e2_eta + e_mu_intrinsic**2) return (self._e2_mag + b**2 * self._e2_eta + e_mu_intrinsic**2)
def predict_zcosmo_from_calibration(self, **kwargs):
    """
    Placeholder for predicting the cosmological redshift from calibration
    parameters. Not available for this model: always raises
    `NotImplementedError`, regardless of the keyword arguments.
    """
    message = "This method is not implemented yet."
    raise NotImplementedError(message)
def predict_zobs_single(self, Vext_x, Vext_y, Vext_z, alpha, beta, def predict_zobs_single(self, Vext_x, Vext_y, Vext_z, alpha, beta,
e_mu_intrinsic, a, b, **kwargs): e_mu_intrinsic, a, b, **kwargs):
""" """
@ -1214,7 +1342,8 @@ class TF_PV_validation_model(BaseFlowValidationModel):
# Calculate p(z_obs) and multiply it by p(r) # Calculate p(z_obs) and multiply it by p(r)
zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i]) zobs_pred = self._f_zobs(beta, Vext_rad[i], self._los_velocity[i])
ptilde *= calculate_ll_zobs(self._z_obs[i], zobs_pred, sigma_v) ptilde *= calculate_likelihood_zobs(
self._z_obs[i], zobs_pred, sigma_v)
return ll + jnp.log(self._f_simps(ptilde) / pnorm), None return ll + jnp.log(self._f_simps(ptilde) / pnorm), None
@ -1475,3 +1604,213 @@ def optimize_model_with_jackknife(loader, k, n_splits=5, sample_alpha=True,
loader.reset_mask() loader.reset_mask()
return samples, stats, fmin, logz, bic return samples, stats, fmin, logz, bic
###############################################################################
# Predicting z_cosmo from z_obs #
###############################################################################
def _posterior_element(r, beta, Vext_radial, los_velocity, Omega_m, zobs,
                       sigma_v, alpha, dVdOmega, los_density):
    """
    Helper function to compute the unnormalized posterior in
    `Observed2CosmologicalRedshift`: the likelihood of the observed redshift
    given the predicted one, multiplied by the prior
    `dVdOmega * los_density**alpha`.
    """
    predicted = predict_zobs(r, beta, Vext_radial, los_velocity, Omega_m)
    zobs_likelihood = calculate_likelihood_zobs(zobs, predicted, sigma_v)
    return zobs_likelihood * (dVdOmega * los_density**alpha)
class BaseObserved2CosmologicalRedshift(ABC):
    """
    Base class for `Observed2CosmologicalRedshift`. Validates and stores the
    flow calibration samples and prepares the JIT-compiled integration and
    posterior helpers.

    Parameters
    ----------
    calibration_samples : dict
        Dictionary of 1-dimensional arrays of equal length. The input
        dictionary is not modified; a validated copy cast to `float32` is
        stored instead. If `alpha` or `beta` is missing it is set to ones.
    r_xrange : 1-dimensional array
        Radial grid. It must contain at least two points and have a constant
        step size, which is required by Simpson's rule.

    Raises
    ------
    ValueError
        If no calibration samples are given, if any of them is not a 1D
        array, if their lengths differ, or if `r_xrange` is too short or
        not evenly spaced.
    """

    def __init__(self, calibration_samples, r_xrange):
        dt = jnp.float32

        if not calibration_samples:
            raise ValueError("No calibration samples were provided.")

        # Validate into a fresh dictionary so that the dictionary owned by
        # the caller is left untouched.
        samples = {}
        ncalibration = None
        for key, x in calibration_samples.items():
            if not isinstance(x, (np.ndarray, jnp.ndarray)):
                raise ValueError(f"Calibration samples `{key}` must be an array.")  # noqa

            if x.ndim != 1:
                raise ValueError(f"Calibration samples `{key}` must be 1D.")

            if ncalibration is None:
                ncalibration = len(x)
            elif len(x) != ncalibration:
                raise ValueError("Calibration samples do not have the same length.")  # noqa

            # Enforce the same data type.
            samples[key] = jnp.asarray(x, dtype=dt)

        # Missing `alpha` and `beta` default to unity.
        for key in ("alpha", "beta"):
            if key not in samples:
                samples[key] = jnp.ones(ncalibration, dtype=dt)

        # Get the stepsize, we need it to be constant for Simpson's rule.
        dr = np.diff(r_xrange)
        if dr.size == 0:
            raise ValueError("`r_xrange` must contain at least two points.")
        if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
            raise ValueError("The radial step size must be constant.")
        dr = dr[0]

        self._calibration_samples = samples
        self._ncalibration_samples = ncalibration

        # It is best to JIT compile the functions right here.
        self._vmap_simps = jit(vmap(lambda y: simps(y, dr)))
        # Map `_posterior_element` over its arguments `r`, `los_velocity`,
        # `dVdOmega` and `los_density`; the remaining ones are broadcast.
        axs = (0, None, None, 0, None, None, None, None, 0, 0)
        self._vmap_posterior_element = vmap(_posterior_element, in_axes=axs)
        self._vmap_posterior_element = jit(self._vmap_posterior_element)

        self._simps = jit(lambda y: simps(y, dr))

    def get_calibration_samples(self, key):
        """
        Get calibration samples for a given key.

        Parameters
        ----------
        key : str
            Key of the calibration samples.

        Returns
        -------
        1-dimensional array

        Raises
        ------
        ValueError
            If `key` is not among the stored calibration samples.
        """
        if key not in self._calibration_samples:
            raise ValueError(f"Key `{key}` not found in calibration samples. Available keys are: `{self.calibration_keys}`.")  # noqa

        return self._calibration_samples[key]

    @property
    def ncalibration_samples(self):
        """
        Number of calibration samples.

        Returns
        -------
        int
        """
        return self._ncalibration_samples

    @property
    def calibration_keys(self):
        """
        Calibration sample keys.

        Returns
        -------
        list of str
        """
        return list(self._calibration_samples.keys())
class Observed2CosmologicalRedshift(BaseObserved2CosmologicalRedshift):
    """
    Model to predict the cosmological redshift from the observed redshift.

    Parameters
    ----------
    calibration_samples : dict
        Dictionary of flow calibration samples (`alpha`, `beta`, `Vext_x`,
        `Vext_y`, `Vext_z`, `sigma_v`, ...).
    r_xrange : 1-dimensional array
        Radial comoving distances where the fields are interpolated for each
        object.
    Omega_m : float
        Matter density parameter.
    """

    def __init__(self, calibration_samples, r_xrange, Omega_m):
        super().__init__(calibration_samples, r_xrange)
        self._r_xrange = jnp.asarray(r_xrange, dtype=jnp.float32)
        # Cosmological redshift corresponding to each radial grid point.
        self._zcos_xrange = dist2redshift(self._r_xrange, Omega_m)
        self._Omega_m = Omega_m

        # Comoving volume element with some arbitrary normalization
        dVdOmega = gradient_redshift2dist(self._zcos_xrange, Omega_m)
        # TODO: Decide about the presence of this correction.
        dVdOmega *= self._r_xrange**2
        self._dVdOmega = dVdOmega / jnp.mean(dVdOmega)

    def posterior_mean_std(self, x, px):
        """
        Calculate the mean and standard deviation of a 1-dimensional PDF.
        Assumes that the PDF is already normalized. Assumes that the PDF
        spacing is that of `r_xrange` which is inferred when initializing this
        class.

        Parameters
        ----------
        x : 1-dimensional array
            Values at which the PDF is evaluated.
        px : 1-dimensional array
            PDF values. Note that the PDF must be normalized.

        Returns
        -------
        mu, std : floats
        """
        mu = self._simps(x * px)
        # Integrate the central second moment directly. The algebraically
        # equivalent `E[x^2] - mu^2` suffers from catastrophic cancellation
        # in single precision for narrow PDFs and may even become negative,
        # in which case the square root would be NaN.
        std = self._simps((x - mu)**2 * px)**0.5
        return mu, std

    def posterior_zcosmo(self, zobs, RA, dec, los_density, los_velocity,
                         extra_sigma_v=None, verbose=True):
        """
        Calculate `p(z_cosmo | z_CMB, calibration)` for a single object.

        Parameters
        ----------
        zobs : float
            Observed redshift.
        RA, dec : float
            Right ascension and declination in radians.
        los_density : 1-dimensional array
            LOS density field.
        los_velocity : 1-dimensional array
            LOS radial velocity field.
        extra_sigma_v : float, optional
            Any additional velocity uncertainty. It is added in quadrature
            to the calibration samples of `sigma_v`.
        verbose : bool, optional
            Verbosity flag.

        Returns
        -------
        zcosmo : 1-dimensional array
            Cosmological redshift at which the PDF is evaluated.
        posterior : 1-dimensional array
            Posterior PDF. Each calibration sample is normalized with
            Simpson's rule on the `r_xrange` grid and the samples are then
            averaged, ignoring NaNs.
        """
        # External velocity of each calibration sample projected onto the
        # line of sight of this object.
        Vext_radial = project_Vext(
            self.get_calibration_samples("Vext_x"),
            self.get_calibration_samples("Vext_y"),
            self.get_calibration_samples("Vext_z"),
            RA, dec)

        alpha = self.get_calibration_samples("alpha")
        beta = self.get_calibration_samples("beta")
        sigma_v = self.get_calibration_samples("sigma_v")

        if extra_sigma_v is not None:
            sigma_v = jnp.sqrt(sigma_v**2 + extra_sigma_v**2)

        # One row per calibration sample, one column per radial grid point.
        posterior = np.zeros((self.ncalibration_samples, len(self._r_xrange)),
                             dtype=np.float32)
        for i in trange(self.ncalibration_samples, desc="Marginalizing",
                        disable=not verbose):
            posterior[i] = self._vmap_posterior_element(
                self._r_xrange, beta[i], Vext_radial[i], los_velocity,
                self._Omega_m, zobs, sigma_v[i], alpha[i], self._dVdOmega,
                los_density)

        # Normalize the posterior for each flow sample and then stack them.
        posterior /= self._vmap_simps(posterior).reshape(-1, 1)
        posterior = jnp.nanmean(posterior, axis=0)

        return self._zcos_xrange, posterior

File diff suppressed because one or more lines are too long