diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index cbb56dd..b573309 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -45,13 +45,13 @@ H0 = 100 # km / s / Mpc # JAX Flow model # ############################################################################### -def dist2redshift(dist, Omega_m): +def dist2redshift(dist, Omega_m, h=1.): """ Convert comoving distance to cosmological redshift if the Universe is flat and z << 1. """ eta = 3 * Omega_m / 2 - return 1 / eta * (1 - (1 - 2 * H0 * dist / SPEED_OF_LIGHT * eta)**0.5) + return 1 / eta * (1 - (1 - 2 * 100 * h * dist / SPEED_OF_LIGHT * eta)**0.5) def redshift2dist(z, Omega_m): @@ -366,11 +366,6 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, } -############################################################################### -# PV calibration model # -############################################################################### - - def sample_gaussian_hyperprior(param, name, xmin, xmax): """Sample MNR Gaussian hyperprior mean and standard deviation.""" mean = sample(f"{param}_mean_{name}", Uniform(xmin, xmax)) @@ -378,9 +373,15 @@ def sample_gaussian_hyperprior(param, name, xmin, xmax): return mean, std +############################################################################### +# PV calibration model without absolute calibration # +############################################################################### + + class PV_LogLikelihood(BaseFlowValidationModel): """ - Peculiar velocity validation model log-likelihood. + Peculiar velocity validation model log-likelihood with numerical + integration of the true distances. Parameters ---------- @@ -408,11 +409,15 @@ class PV_LogLikelihood(BaseFlowValidationModel): Catalogue kind, either "TFR", "SN", or "simple". name : str Name of the catalogue. + with_num_dist_marginalisation : bool, optional + Whether to use numerical distance marginalisation, in which case + the tracers cannot be coupled by a covariance matrix. By default + `True`. """ def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, calibration_params, abs_calibration_params, mag_selection, - r_xrange, Omega_m, kind, name): + r_xrange, Omega_m, kind, name, with_num_dist_marginalisation): if e_zobs is not None: e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) else: @@ -433,8 +438,12 @@ class PV_LogLikelihood(BaseFlowValidationModel): self.kind = kind self.name = name self.Omega_m = Omega_m + self.with_num_dist_marginalisation = with_num_dist_marginalisation self.norm = - self.ndata * jnp.log(self.num_sims) + # TODO: Somewhere here prepare the interpolators in case of no + # numerical marginalisation. + if mag_selection is not None: self.mag_selection_kind = mag_selection["kind"] @@ -485,6 +494,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): Vmono = field_calibration_params["Vmono"] Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec) + # ------------------------------------------------------------ + # 1. Sample true observables and obtain the distance estimate + # ------------------------------------------------------------ e_mu = distmod_params["e_mu"] if self.kind == "SN": mag_cal = distmod_params["mag_cal"] @@ -532,10 +544,6 @@ class PV_LogLikelihood(BaseFlowValidationModel): mu = distmod_SN( mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal) - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for SN not implemented.") - elif self.kind == "TFR": a = distmod_params["a"] b = distmod_params["b"] @@ -605,11 +613,6 @@ class PV_LogLikelihood(BaseFlowValidationModel): e2_mu = jnp.ones_like(mag_true) * e_mu**2 mu = distmod_TFR(mag_true, eta_true, a, b, c) - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for TFR not implemented.") - # mu -= 5 * jnp.log10(field_calibration_params["h"]) - elif self.kind == "simple": dmu = distmod_params["dmu"] @@ -628,61 +631,73 @@ class PV_LogLikelihood(BaseFlowValidationModel): e2_mu = jnp.ones_like(mag_true) * e_mu**2 mu = mu_true + dmu - - if field_calibration_params["sample_h"]: - raise NotImplementedError("H0 for simple not implemented.") else: raise ValueError(f"Unknown kind: `{self.kind}`.") - # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) - log_ptilde = log_ptilde_wo_bias( - self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], - self.log_r2_xrange[None, :]) + # h = field_calibration_params["h"] + # ---------------------------------------------------------------- + # 2. Log-likelihood of the true distance and observed redshifts. + # The marginalisation of the true distance can be done numerically. + # ---------------------------------------------------------------- + if self.with_num_dist_marginalisation: - # Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange) - alpha = distmod_params["alpha"] - log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density + if field_calibration_params["sample_h"]: + raise NotImplementedError("Sampling of h not implemented.") + # Rescale the grid to account for the sampled H0. For distance + # modulus going from Mpc / h to Mpc implies larger numerical + # values, so there has to be a minus sign since h < 1. + # mu_xrange = self.mu_xrange - 5 * jnp.log(h) - ptilde = jnp.exp(log_ptilde) + # The redshift should also be boosted since now the object are + # further away? - # Normalization of p(r). Shape is (n_sims, ndata) - pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) + # Actually, the redshift ought to remain the same? + else: + mu_xrange = self.mu_xrange - # Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange) - vrad = field_calibration_params["beta"] * self.los_velocity - vrad += (Vext_rad[None, :, None] + Vmono) - zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT) - zobs -= 1. + # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) + log_ptilde = log_ptilde_wo_bias( + mu_xrange[None, :], mu[:, None], e2_mu[:, None], + self.log_r2_xrange[None, :]) - # Shape remains (n_sims, ndata, nxrange) - ptilde *= likelihood_zobs( - self.z_obs[None, :, None], zobs, e2_cz[None, :, None]) + # Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange) + alpha = distmod_params["alpha"] + log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density - if self.with_absolute_calibration: - raise NotImplementedError("Absolute calibration not implemented.") - # Absolute calibration likelihood, the shape is now - # (ndata_with_calibration, ncalib, nxrange) - # ll_calibration = normal_logpdf( - # self.mu_xrange[None, None, :], - # self.calibration_distmod[..., None], - # self.calibration_edistmod[..., None]) + ptilde = jnp.exp(log_ptilde) - # # Average the likelihood over the calibration points. The shape - # is - # # now (ndata, nxrange) - # ll_calibration = logsumexp( - # jnp.nan_to_num(ll_calibration, nan=-jnp.inf), axis=1) - # # This is the normalisation because we want the *average*. - # ll_calibration -= self.log_length_calibration[:, None] + # Normalization of p(r). Shape: (nsims, ndata) + pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) - # ptilde = ptilde.at[:, self.data_with_calibration, :]. - # multiply(jnp.exp(ll_calibration)) + # Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange) + vrad = field_calibration_params["beta"] * self.los_velocity + vrad += (Vext_rad[None, :, None] + Vmono) + zobs = 1 + self.z_xrange[None, None, :] + zobs *= 1 + vrad / SPEED_OF_LIGHT + zobs -= 1. - # Integrate over the radial distance. Shape is (n_sims, ndata) - ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1)) - ll -= jnp.log(pnorm) + # Shape remains (nsims, ndata, nxrange) + ptilde *= likelihood_zobs( + self.z_obs[None, :, None], zobs, e2_cz[None, :, None]) - return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm + if self.with_absolute_calibration: + raise NotImplementedError( + "Absolute calibration not implemented for this model. " + "Use `PV_LogLikelihood_NoDistMarg` instead.") + + # Integrate over the radial distance. Shape: (nsims, ndata) + ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1)) + ll -= jnp.log(pnorm) + + return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm + else: + raise NotImplementedError( + "No distance marginalisation not implemented yet.") + + +############################################################################### +# Combining several catalogues # +############################################################################### def PV_validation_model(models, distmod_hyperparams_per_model,