diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index 85e1d46..757f0bf 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -74,6 +74,20 @@ def gradient_redshift2dist(z, Omega_m): return SPEED_OF_LIGHT / H0 * (1 - z * (1 + q0)) +def distmod2dist(mu, Om0): + """ + Convert distance modulus to distance in `Mpc / h`. The expression is valid + for a flat universe over the range of 0.00001 < z < 0.1. + """ + term1 = jnp.exp((0.443288 * mu) + (-14.286531)) + term2 = (0.506973 * mu) + 12.954633 + term3 = ((0.028134 * mu) ** ( + ((0.684713 * mu) + + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) + term4 = (-0.045160) * mu + return (-0.000301) + (term1 * (term2 - (term3 - term4))) + + def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians): """Project the external velocity vector onto the line of sight.""" cos_dec = jnp.cos(dec_radians) @@ -113,6 +127,15 @@ def likelihood_zobs(zobs, zobs_pred, e2_cz): return jnp.exp(-0.5 * dcz**2 / e2_cz) / jnp.sqrt(2 * np.pi * e2_cz) +def log_likelihood_zobs(zobs, zobs_pred, e2_cz): + """ + Calculate the log-likelihood of the observed redshift given the predicted + redshift. Multiplies the redshifts by the speed of light. + """ + dcz = SPEED_OF_LIGHT * (zobs - zobs_pred) + return -0.5 * dcz**2 / e2_cz - 0.5 * jnp.log(2 * np.pi * e2_cz) + + def normal_logpdf(x, loc, scale): """Log of the normal probability density function.""" return (-0.5 * ((x - loc) / scale)**2 @@ -180,9 +203,12 @@ class BaseFlowValidationModel(ABC): r_xrange = jnp.asarray(r_xrange) r2_xrange = r_xrange**2 - r2_xrange /= r2_xrange.mean() + r2_xrange_mean = r2_xrange.mean() + r2_xrange /= r2_xrange_mean + self.r_xrange = r_xrange self.log_r2_xrange = jnp.log(r2_xrange) + self.log_r2_xrange_mean = jnp.log(r2_xrange_mean) # Require `zmin` < 0 because the first radial step is likely at 0. z_xrange = z_at_value( @@ -235,6 +261,13 @@ class BaseFlowValidationModel(ABC): return len(self.log_los_density()) + def los_density(self, **kwargs): + if self.is_void_data: + # Currently we have no densities for the void. + return jnp.ones((1, self.ndata, len(self.r_xrange))) + + return self._los_density + def log_los_density(self, **kwargs): if self.is_void_data: # Currently we have no densities for the void. @@ -491,11 +524,16 @@ class PV_LogLikelihood(BaseFlowValidationModel): names = ["RA", "dec", "z_obs", "e2_cz_obs"] values = [RA, dec, z_obs, e2_cz_obs] + # If ever start running out of memory, may be better not to store + # both the density and log_density if not self.is_void_data: names += ["_log_los_density", "_los_velocity"] values += [jnp.log(los_density), los_velocity] - # Set the void data + # Density required only if not numerically marginalising. + if not with_num_dist_marginalisation: + names += ["_los_density"] + values += [los_density] self._setattr_as_jax(names, values) self._set_calibration_params(calibration_params) @@ -691,13 +729,12 @@ class PV_LogLikelihood(BaseFlowValidationModel): if inference_method == "bayes": raise NotImplementedError("Bayes for simple not implemented.") else: - mu_true = self.mu if inference_method == "mike": e2_mu = e_mu**2 + self.e2_mu else: e2_mu = jnp.ones_like(mag_true) * e_mu**2 - mu = mu_true + dmu + mu = self.mu + dmu else: raise ValueError(f"Unknown kind: `{self.kind}`.") @@ -765,8 +802,74 @@ class PV_LogLikelihood(BaseFlowValidationModel): return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm else: - raise NotImplementedError( - "No distance marginalisation not implemented yet.") + if field_calibration_params["sample_h"]: + raise NotImplementedError("Sampling of h not implemented.") + + e_mu = jnp.sqrt(e2_mu) + # True distance modulus, shape is `(n_data)`` + with plate("plate_mu", self.ndata): + mu_true = sample("mu", Normal(mu, e_mu)) + + # True distance, shape is `(n_data)`` + r_true = distmod2dist(mu_true, self.Omega_m) + # TODO: + z_true = None + + if self.is_void_data: + raise NotImplementedError( + "Void data not implemented yet for distance sampling.") + else: + # grid log(density), shape is `(n_sims, n_data, n_rad)` + log_los_density_grid = self.los_density() + + # TODO: Need to add the interpolators for these + # Densities and velocities at the true distances, shape is + # `(n_sims, n_data)` + log_density = None + los_velocity = None + + alpha = distmod_params["alpha"] + + # Check dimensions of all this + + # Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)` + pnorm = ( + self.log_r2_xrange[None, None, :] + + alpha * log_los_density_grid + + normal_logpdf( + self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa + + pnorm = jnp.exp(pnorm) + + # Normalization of p(mu). Shape is now (nsims, ndata) + pnorm = simpson(pnorm, x=self.r_xrange, axis=-1) + + # TODO: There should be a Jacobian? + # Calculate unnormalized log p(mu). Shape is (nsims, ndata) + ll = ( + 2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :] + + alpha * log_density + + normal_logpdf(mu_true, mu, e_mu)[None, :]) + + # Subtract the normalization. Shape remains (nsims, ndata) + ll -= jnp.log(pnorm) + + # Calculate z_obs at the true distance. Shape: (nsims, ndata) + vrad = field_calibration_params["beta"] * los_velocity + vrad += (Vext_rad[None, :] + Vmono) + zobs = 1 + z_true[None, :] + zobs *= 1 + vrad / SPEED_OF_LIGHT + zobs -= 1. + + ll += log_likelihood_zobs( + self.z_obs[None, :], zobs, e2_cz[None, :]) + + if self.with_absolute_calibration: + raise NotImplementedError( + "Absolute calibration not implemented for this model. " + "Use `PV_LogLikelihood_NoDistMarg` instead.") + + return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm ###############################################################################