Add various cleanups of the code

This commit is contained in:
rstiskalek 2024-09-19 14:04:48 +01:00
parent 140d5ea40e
commit 3f4a63d3de

View file

@ -648,7 +648,8 @@ def sample_gaussian_hyperprior(param, name, xmin, xmax):
class PV_LogLikelihood(BaseFlowValidationModel): class PV_LogLikelihood(BaseFlowValidationModel):
""" """
Peculiar velocity validation model log-likelihood. Peculiar velocity validation model log-likelihood with numerical
integration of the true distances.
Parameters Parameters
---------- ----------
@ -676,11 +677,15 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Catalogue kind, either "TFR", "SN", or "simple". Catalogue kind, either "TFR", "SN", or "simple".
name : str name : str
Name of the catalogue. Name of the catalogue.
with_num_dist_marginalisation : bool, optional
Whether to use numerical distance marginalisation, in which case
the tracers cannot be coupled by a covariance matrix. By default
`True`.
""" """
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
calibration_params, abs_calibration_params, mag_selection, calibration_params, abs_calibration_params, mag_selection,
r_xrange, Omega_m, kind, name): r_xrange, Omega_m, kind, name, with_num_dist_marginalisation):
if e_zobs is not None: if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else: else:
@ -701,8 +706,12 @@ class PV_LogLikelihood(BaseFlowValidationModel):
self.kind = kind self.kind = kind
self.name = name self.name = name
self.Omega_m = Omega_m self.Omega_m = Omega_m
self.with_num_dist_marginalisation = with_num_dist_marginalisation
self.norm = - self.ndata * jnp.log(self.num_sims) self.norm = - self.ndata * jnp.log(self.num_sims)
# TODO: Somewhere here prepare the interpolators in case of no
# numerical marginalisation.
if mag_selection is not None: if mag_selection is not None:
self.mag_selection_kind = mag_selection["kind"] self.mag_selection_kind = mag_selection["kind"]
@ -753,6 +762,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Vmono = field_calibration_params["Vmono"] Vmono = field_calibration_params["Vmono"]
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec) Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
# ------------------------------------------------------------
# 1. Sample true observables and obtain the distance estimate
# ------------------------------------------------------------
e_mu = distmod_params["e_mu"] e_mu = distmod_params["e_mu"]
if self.kind == "SN": if self.kind == "SN":
mag_cal = distmod_params["mag_cal"] mag_cal = distmod_params["mag_cal"]
@ -800,10 +812,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
mu = distmod_SN( mu = distmod_SN(
mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal) mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal)
if field_calibration_params["sample_h"]:
raise NotImplementedError("H0 for SN not implemented.")
elif self.kind == "TFR": elif self.kind == "TFR":
a = distmod_params["a"] a = distmod_params["a"]
b = distmod_params["b"] b = distmod_params["b"]
@ -873,11 +881,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
e2_mu = jnp.ones_like(mag_true) * e_mu**2 e2_mu = jnp.ones_like(mag_true) * e_mu**2
mu = distmod_TFR(mag_true, eta_true, a, b, c) mu = distmod_TFR(mag_true, eta_true, a, b, c)
if field_calibration_params["sample_h"]:
raise NotImplementedError("H0 for TFR not implemented.")
# mu -= 5 * jnp.log10(field_calibration_params["h"])
elif self.kind == "simple": elif self.kind == "simple":
dmu = distmod_params["dmu"] dmu = distmod_params["dmu"]
@ -896,68 +899,52 @@ class PV_LogLikelihood(BaseFlowValidationModel):
e2_mu = jnp.ones_like(mag_true) * e_mu**2 e2_mu = jnp.ones_like(mag_true) * e_mu**2
mu = mu_true + dmu mu = mu_true + dmu
if field_calibration_params["sample_h"]:
raise NotImplementedError("H0 for simple not implemented.")
else: else:
raise ValueError(f"Unknown kind: `{self.kind}`.") raise ValueError(f"Unknown kind: `{self.kind}`.")
# ----------------------------------------------------------------
# 2. Log-likelihood of the true distance and observed redshifts.
# The marginalisation of the true distance can be done numerically.
# ----------------------------------------------------------------
if self.with_num_dist_marginalisation:
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
log_ptilde = log_ptilde_wo_bias( log_ptilde = log_ptilde_wo_bias(
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.log_r2_xrange[None, :]) self.log_r2_xrange[None, :])
# Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange) # Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange)
alpha = distmod_params["alpha"] alpha = distmod_params["alpha"]
log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density
ptilde = jnp.exp(log_ptilde) ptilde = jnp.exp(log_ptilde)
# Normalization of p(r). Shape is (n_sims, ndata) # Normalization of p(r). Shape: (nsims, ndata)
pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) pnorm = simpson(ptilde, x=self.r_xrange, axis=-1)
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange) # Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange)
vrad = field_calibration_params["beta"] * self.los_velocity vrad = field_calibration_params["beta"] * self.los_velocity
vrad += (Vext_rad[None, :, None] + Vmono) vrad += (Vext_rad[None, :, None] + Vmono)
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) zobs = 1 + self.z_xrange[None, None, :]
zobs *= 1 + vrad / SPEED_OF_LIGHT
zobs -= 1. zobs -= 1.
# Shape remains (n_sims, ndata, nxrange) # Shape remains (nsims, ndata, nxrange)
ptilde *= likelihood_zobs( ptilde *= likelihood_zobs(
self.z_obs[None, :, None], zobs, e2_cz[None, :, None]) self.z_obs[None, :, None], zobs, e2_cz[None, :, None])
if self.with_absolute_calibration: if self.with_absolute_calibration:
raise NotImplementedError("Absolute calibration not implemented.") raise NotImplementedError(
# Absolute calibration likelihood, the shape is now "Absolute calibration not implemented for this model. "
# (ndata_with_calibration, ncalib, nxrange) "Use `PV_LogLikelihood_NoDistMarg` instead.")
# ll_calibration = normal_logpdf(
# self.mu_xrange[None, None, :],
# self.calibration_distmod[..., None],
# self.calibration_edistmod[..., None])
# # Average the likelihood over the calibration points. The shape # Integrate over the radial distance. Shape: (nsims, ndata)
# is
# # now (ndata, nxrange)
# ll_calibration = logsumexp(
# jnp.nan_to_num(ll_calibration, nan=-jnp.inf), axis=1)
# # This is the normalisation because we want the *average*.
# ll_calibration -= self.log_length_calibration[:, None]
# ptilde = ptilde.at[:, self.data_with_calibration, :].
# multiply(jnp.exp(ll_calibration))
# Integrate over the radial distance. Shape is (n_sims, ndata)
ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1)) ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1))
ll -= jnp.log(pnorm) ll -= jnp.log(pnorm)
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
else:
raise NotImplementedError(
############################################################################### "No distance marginalisation not implemented yet.")
# PV calibration model with absolute calibration #
###############################################################################
############################################################################### ###############################################################################