diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index 6494927..5916de4 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -648,7 +648,8 @@ def sample_gaussian_hyperprior(param, name, xmin, xmax): class PV_LogLikelihood(BaseFlowValidationModel): """ - Peculiar velocity validation model log-likelihood. + Peculiar velocity validation model log-likelihood with numerical + integration of the true distances. Parameters ---------- @@ -676,11 +677,15 @@ class PV_LogLikelihood(BaseFlowValidationModel): Catalogue kind, either "TFR", "SN", or "simple". name : str Name of the catalogue. + with_num_dist_marginalisation : bool, optional + Whether to use numerical distance marginalisation, in which case + the tracers cannot be coupled by a covariance matrix. By default + `True`. """ def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, calibration_params, abs_calibration_params, mag_selection, - r_xrange, Omega_m, kind, name): + r_xrange, Omega_m, kind, name, with_num_dist_marginalisation): if e_zobs is not None: e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) else: @@ -701,8 +706,12 @@ class PV_LogLikelihood(BaseFlowValidationModel): self.kind = kind self.name = name self.Omega_m = Omega_m + self.with_num_dist_marginalisation = with_num_dist_marginalisation self.norm = - self.ndata * jnp.log(self.num_sims) + # TODO: Somewhere here prepare the interpolators in case of no + # numerical marginalisation. + if mag_selection is not None: self.mag_selection_kind = mag_selection["kind"] @@ -753,6 +762,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): Vmono = field_calibration_params["Vmono"] Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec) + # ------------------------------------------------------------ + # 1. Sample true observables and obtain the distance estimate + # ------------------------------------------------------------ e_mu = distmod_params["e_mu"] if self.kind == "SN": mag_cal = distmod_params["mag_cal"] @@ -800,10 +812,6 @@ class PV_LogLikelihood(BaseFlowValidationModel): mu = distmod_SN( mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal) - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for SN not implemented.") - elif self.kind == "TFR": a = distmod_params["a"] b = distmod_params["b"] @@ -873,11 +881,6 @@ class PV_LogLikelihood(BaseFlowValidationModel): e2_mu = jnp.ones_like(mag_true) * e_mu**2 mu = distmod_TFR(mag_true, eta_true, a, b, c) - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for TFR not implemented.") - # mu -= 5 * jnp.log10(field_calibration_params["h"]) - elif self.kind == "simple": dmu = distmod_params["dmu"] @@ -896,68 +899,52 @@ class PV_LogLikelihood(BaseFlowValidationModel): e2_mu = jnp.ones_like(mag_true) * e_mu**2 mu = mu_true + dmu - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for simple not implemented.") else: raise ValueError(f"Unknown kind: `{self.kind}`.") - # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) - log_ptilde = log_ptilde_wo_bias( - self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], - self.log_r2_xrange[None, :]) + # ---------------------------------------------------------------- + # 2. Log-likelihood of the true distance and observed redshifts. + # The marginalisation of the true distance can be done numerically. + # ---------------------------------------------------------------- + if self.with_num_dist_marginalisation: + # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) + log_ptilde = log_ptilde_wo_bias( + self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], + self.log_r2_xrange[None, :]) - # Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange) - alpha = distmod_params["alpha"] - log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density + # Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange) + alpha = distmod_params["alpha"] + log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density - ptilde = jnp.exp(log_ptilde) + ptilde = jnp.exp(log_ptilde) - # Normalization of p(r). Shape is (n_sims, ndata) - pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) + # Normalization of p(r). Shape: (nsims, ndata) + pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) - # Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange) - vrad = field_calibration_params["beta"] * self.los_velocity - vrad += (Vext_rad[None, :, None] + Vmono) - zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - zobs -= 1. + # Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange) + vrad = field_calibration_params["beta"] * self.los_velocity + vrad += (Vext_rad[None, :, None] + Vmono) + zobs = 1 + self.z_xrange[None, None, :] + zobs *= 1 + vrad / SPEED_OF_LIGHT + zobs -= 1. - # Shape remains (n_sims, ndata, nxrange) - ptilde *= likelihood_zobs( - self.z_obs[None, :, None], zobs, e2_cz[None, :, None]) + # Shape remains (nsims, ndata, nxrange) + ptilde *= likelihood_zobs( + self.z_obs[None, :, None], zobs, e2_cz[None, :, None]) - if self.with_absolute_calibration: - raise NotImplementedError("Absolute calibration not implemented.") - # Absolute calibration likelihood, the shape is now - # (ndata_with_calibration, ncalib, nxrange) - # ll_calibration = normal_logpdf( - # self.mu_xrange[None, None, :], - # self.calibration_distmod[..., None], - # self.calibration_edistmod[..., None]) - - # # Average the likelihood over the calibration points. The shape - # is - # # now (ndata, nxrange) - # ll_calibration = logsumexp( - # jnp.nan_to_num(ll_calibration, nan=-jnp.inf), axis=1) - # # This is the normalisation because we want the *average*. - # ll_calibration -= self.log_length_calibration[:, None] - - # ptilde = ptilde.at[:, self.data_with_calibration, :]. - # multiply(jnp.exp(ll_calibration)) - - # Integrate over the radial distance. Shape is (n_sims, ndata) - ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1)) - ll -= jnp.log(pnorm) - - return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm - - -############################################################################### -# PV calibration model with absolute calibration # -############################################################################### + if self.with_absolute_calibration: + raise NotImplementedError( + "Absolute calibration not implemented for this model. " + "Use `PV_LogLikelihood_NoDistMarg` instead.") + # Integrate over the radial distance. Shape: (nsims, ndata) + ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1)) + ll -= jnp.log(pnorm) + return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm + else: + raise NotImplementedError( + "No distance marginalisation not implemented yet.") ###############################################################################