diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index c83fb0a..d2443e6 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -27,8 +27,10 @@ import numpy as np from astropy import units as u from astropy.coordinates import SkyCoord, angular_separation from astropy.cosmology import FlatLambdaCDM, z_at_value +from interpax import interp1d from jax import jit from jax import numpy as jnp +from jax import vmap from jax.scipy.special import erf, logsumexp from numpyro import factor, plate, sample from numpyro.distributions import MultivariateNormal, Normal, Uniform @@ -88,6 +90,39 @@ def distmod2dist(mu, Om0): return (-0.000301) + (term1 * (term2 - (term3 - term4))) +def distmod2dist_gradient(mu, Om0): + """ + Calculate the derivative of comoving distance in `Mpc / h` with respect to + the distance modulus. The expression is valid for a flat universe over the + range of 0.00001 < z < 0.1. + """ + term1 = jnp.exp((0.443288 * mu) + (-14.286531)) + dterm1 = 0.443288 * term1 + + term2 = (0.506973 * mu) + 12.954633 + dterm2 = 0.506973 + + term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa + ln_base = jnp.log(0.028134) + jnp.log(mu) + exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu) + exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu) + dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base) + + term4 = (-0.045160) * mu + dterm4 = -0.045160 + + return (dterm1 * (term2 - (term3 - term4)) + + term1 * (dterm2 - (dterm3 - dterm4))) + + +def distmod2redshift(mu, Om0): + """ + Convert distance modulus to redshift, assuming `h = 1`. The expression is + valid for a flat universe over the range of 0.00001 < z < 0.1. + """ + return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa + + def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians): """Project the external velocity vector onto the line of sight.""" cos_dec = jnp.cos(dec_radians) @@ -150,6 +185,37 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax): return normal_logpdf(x, loc, scale) - jnp.log(norm) +############################################################################### +# LOS interpolation # +############################################################################### + + +def interpolate_los(r, los, rgrid, method="cubic"): + """ + Interpolate the LOS field at a given radial distance. + + Parameters + ---------- + r : 1-dimensional array of shape `(n_gal, )` + Radial distances at which to interpolate the LOS field. + los : 3-dimensional array of shape `(n_sims, n_gal, n_steps)` + LOS field. + rmin, rmax : float + Minimum and maximum radial distances in the data. + order : int, optional + The order of the interpolation. Default is 1, can be 0. + + Returns + ------- + 2-dimensional array of shape `(n_sims, n_gal)` + """ + # Vectorize over the inner loop (ngal) first, then the outer loop (nsim) + def f(rn, los_row): + return interp1d(rn, rgrid, los_row, method=method) + + return vmap(vmap(f, in_axes=(0, 0)), in_axes=(None, 0))(r, los) + + ############################################################################### # Base flow validation # ############################################################################### @@ -291,6 +357,12 @@ class BaseFlowValidationModel(ABC): return self._los_velocity + def log_los_density_at_r(self, r): + return interpolate_los(r, self.log_los_density(), self.r_xrange, ) + + def los_velocity_at_r(self, r): + return interpolate_los(r, self.los_velocity(), self.r_xrange, ) + @abstractmethod def __call__(self, **kwargs): pass @@ -514,16 +586,16 @@ class PV_LogLikelihood(BaseFlowValidationModel): Name of the catalogue. void_kwargs : dict, optional Void data parameters. If `None` the data is not void data. - with_num_dist_marginalisation : bool, optional - Whether to use numerical distance marginalisation, in which case - the tracers cannot be coupled by a covariance matrix. By default - `True`. + wo_num_dist_marginalisation : bool, optional + Whether to directly sample the distance without numerical + marginalisation. in which case the tracers can be coupled by a + covariance matrix. By default `False`. """ def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, calibration_params, abs_calibration_params, mag_selection, r_xrange, Omega_m, kind, name, void_kwargs=None, - with_num_dist_marginalisation=True): + wo_num_dist_marginalisation=False): if e_zobs is not None: e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) else: @@ -549,7 +621,7 @@ class PV_LogLikelihood(BaseFlowValidationModel): values += [jnp.log(los_density), los_velocity] # Density required only if not numerically marginalising. - if not with_num_dist_marginalisation: + if not wo_num_dist_marginalisation: names += ["_los_density"] values += [los_density] @@ -561,12 +633,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): self.kind = kind self.name = name self.Omega_m = Omega_m - self.with_num_dist_marginalisation = with_num_dist_marginalisation + self.wo_num_dist_marginalisation = wo_num_dist_marginalisation self.norm = - self.ndata * jnp.log(self.num_sims) - # TODO: Somewhere here prepare the interpolators in case of no - # numerical marginalisation. - if mag_selection is not None: self.mag_selection_kind = mag_selection["kind"] @@ -772,25 +841,16 @@ class PV_LogLikelihood(BaseFlowValidationModel): # 2. Log-likelihood of the true distance and observed redshifts. # The marginalisation of the true distance can be done numerically. # ---------------------------------------------------------------- - if self.with_num_dist_marginalisation: + if not self.wo_num_dist_marginalisation: if field_calibration_params["sample_h"]: - raise NotImplementedError("Sampling of h not implemented.") - # Rescale the grid to account for the sampled H0. For distance - # modulus going from Mpc / h to Mpc implies larger numerical - # values, so there has to be a minus sign since h < 1. - # mu_xrange = self.mu_xrange - 5 * jnp.log(h) - - # The redshift should also be boosted since now the object are - # further away? - - # Actually, the redshift ought to remain the same? - else: - mu_xrange = self.mu_xrange + raise NotImplementedError( + "Sampling of 'h' is not supported if numerically " + "marginalising the true distance.") # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) log_ptilde = log_ptilde_wo_bias( - mu_xrange[None, :], mu[:, None], e2_mu[:, None], + self.mu_xrange[None, :], mu[:, None], e2_mu[:, None], self.log_r2_xrange[None, :]) if self.is_void_data: @@ -832,56 +892,52 @@ class PV_LogLikelihood(BaseFlowValidationModel): return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm else: if field_calibration_params["sample_h"]: - raise NotImplementedError("Sampling of h not implemented.") - - raise NotImplementedError( - "Sampling of distance is not implemented. Work in progress.") + raise NotImplementedError( + "Sampling of h is not yet implemented.") e_mu = jnp.sqrt(e2_mu) # True distance modulus, shape is `(n_data)`` with plate("plate_mu", self.ndata): mu_true = sample("mu", Normal(mu, e_mu)) - # True distance, shape is `(n_data)`` + # True distance and redshift, shape is `(n_data)`. r_true = distmod2dist(mu_true, self.Omega_m) - # TODO: - z_true = None + z_true = distmod2redshift(mu_true, self.Omega_m) if self.is_void_data: raise NotImplementedError( "Void data not implemented yet for distance sampling.") else: - # grid log(density), shape is `(n_sims, n_data, n_rad)` - log_los_density_grid = self.los_density() - - # TODO: Need to add the interpolators for these + # Grid log(density), shape is `(n_sims, n_data, n_rad)` + log_los_density_grid = self.log_los_density() # Densities and velocities at the true distances, shape is # `(n_sims, n_data)` - log_density = None - los_velocity = None + log_density = self.log_los_density_at_r(r_true) + los_velocity = self.los_velocity_at_r(r_true) alpha = distmod_params["alpha"] - # Check dimensions of all this - # Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)` pnorm = ( - self.log_r2_xrange[None, None, :] + + self.log_r2_xrange[None, None, :] + alpha * log_los_density_grid + normal_logpdf( self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa - pnorm = jnp.exp(pnorm) - - # Normalization of p(mu). Shape is now (nsims, ndata) + # Now integrate over the radial steps. Shape is `(nsims, ndata)`. + # No Jacobian here because I integrate over distance, not the + # distance modulus. pnorm = simpson(pnorm, x=self.r_xrange, axis=-1) - # TODO: There should be a Jacobian? + # Jacobian |dr / dmu|_(mu_true), shape is `(n_data)`. + jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m)) + # Calculate unnormalized log p(mu). Shape is (nsims, ndata) ll = ( - 2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :] + + jnp.log(jac)[None, :] + + (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :] + alpha * log_density - + normal_logpdf(mu_true, mu, e_mu)[None, :]) + ) # Subtract the normalization. Shape remains (nsims, ndata) ll -= jnp.log(pnorm) diff --git a/scripts/flow/flow_validation.py b/scripts/flow/flow_validation.py index b45bf8f..2d558ab 100644 --- a/scripts/flow/flow_validation.py +++ b/scripts/flow/flow_validation.py @@ -88,7 +88,7 @@ def print_variables(names, variables): def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs, - verbose=True): + wo_num_dist_marginalisation, verbose=True): """Load the data and create the NumPyro models.""" paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring) folder = "/mnt/extraspace/rstiskalek/catalogs/" @@ -128,6 +128,7 @@ def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs, ksmooth=ARGS.ksmooth) models[i] = csiborgtools.flow.get_model( loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs, + wo_num_dist_marginalisation=wo_num_dist_marginalisation, **get_model_kwargs) fprint(f"num. radial steps is {len(loader.rdist)}") @@ -299,10 +300,10 @@ if __name__ == "__main__": ########################################################################### # `None` means default behaviour - nsteps = 10_000 - nburn = 2_000 + nsteps = 1_000 + nburn = 1_000 zcmb_min = None - zcmb_max = 0.05 + zcmb_max = 0.04 nchains_harmonic = 10 num_epochs = 50 inference_method = "mike" @@ -313,8 +314,9 @@ if __name__ == "__main__": sample_Vmag_vax = False sample_Vmono = False sample_mag_dipole = False + wo_num_dist_marginalisation = True absolute_calibration = None - calculate_harmonic = False if inference_method == "bayes" else True + calculate_harmonic = (False if inference_method == "bayes" else True) and (not wo_num_dist_marginalisation) # noqa sample_h = True if absolute_calibration is not None else False fname_kwargs = {"inference_method": inference_method, @@ -341,6 +343,7 @@ if __name__ == "__main__": "num_epochs": num_epochs, "inference_method": inference_method, "sample_mag_dipole": sample_mag_dipole, + "wo_dist_marg": wo_num_dist_marginalisation, "absolute_calibration": absolute_calibration, "sample_h": sample_h, } @@ -420,7 +423,8 @@ if __name__ == "__main__": print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa fname_kwargs["nsim"] = ksim - models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs) + models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs, + wo_num_dist_marginalisation) model_kwargs = { "models": models, "field_calibration_hyperparams": calibration_hyperparams,