diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index b573309..85e1d46 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -25,6 +25,7 @@ from abc import ABC, abstractmethod import numpy as np from astropy import units as u +from astropy.coordinates import SkyCoord, angular_separation from astropy.cosmology import FlatLambdaCDM, z_at_value from jax import jit from jax import numpy as jnp @@ -37,6 +38,7 @@ from tqdm import trange from ..params import SPEED_OF_LIGHT from ..utils import fprint from .selection import toy_log_magnitude_selection +from .void_model import interpolate_void, load_void_data H0 = 100 # km / s / Mpc @@ -193,6 +195,33 @@ class BaseFlowValidationModel(ABC): self.z_xrange = jnp.asarray(z_xrange) self.mu_xrange = jnp.asarray(mu_xrange) + def _set_void_data(self, RA, dec, kind, h, order): + """Create the void interpolator.""" + # h is the MOND model value of local H0 to convert the radial grid to + # Mpc / h + rLG_grid, void_grid = load_void_data(kind) + void_grid = jnp.asarray(void_grid, dtype=jnp.float32) + rLG_grid = jnp.asarray(rLG_grid, dtype=jnp.float32) + + rLG_grid *= h + rLG_min, rLG_max = rLG_grid.min(), rLG_grid.max() + rgrid_min, rgrid_max = 0, 250 + fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc.") + rgrid_max *= h + + # Get angular separation (in degrees) of each object from the model + # axis. + model_axis = SkyCoord(l=117, b=4, frame='galactic', unit='deg').icrs + coords = SkyCoord(ra=RA, dec=dec, unit='deg').icrs + + phi = angular_separation(coords.ra.rad, coords.dec.rad, + model_axis.ra.rad, model_axis.dec.rad) + phi = jnp.asarray(phi * 180 / np.pi, dtype=jnp.float32) + + self.void_interpolator = lambda rLG: interpolate_void( + rLG, self.r_xrange, phi, void_grid, rgrid_min, rgrid_max, + rLG_min, rLG_max, order) + @property def ndata(self): """Number of PV objects in the catalogue.""" @@ -201,7 +230,24 @@ class BaseFlowValidationModel(ABC): @property def num_sims(self): """Number of simulations.""" - return len(self.log_los_density) + if self.is_void_data: + return 1. + + return len(self.log_los_density()) + + def log_los_density(self, **kwargs): + if self.is_void_data: + # Currently we have no densities for the void. + return jnp.zeros((1, self.ndata, len(self.r_xrange))) + + return self._log_los_density + + def los_velocity(self, **kwargs): + if self.is_void_data: + # We want the shape to be `(1, n_objects, n_radial_steps)``. + return self.void_interpolator(kwargs["rLG"])[None, ...] + + return self._los_velocity @abstractmethod def __call__(self, **kwargs): @@ -331,7 +377,8 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max, def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, beta_max, sigma_v_min, sigma_v_max, h_min, h_max, - no_Vext, sample_Vmono, sample_beta, sample_h): + rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta, + sample_h, sample_rLG): """Sample the flow calibration.""" sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max)) @@ -357,12 +404,18 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min, else: h = 1.0 + if sample_rLG: + rLG = sample("rLG", Uniform(rLG_min, rLG_max)) + else: + rLG = None + return {"Vext": Vext, "Vmono": Vmono, "sigma_v": sigma_v, "beta": beta, "h": h, "sample_h": sample_h, + "rLG": rLG, } @@ -386,9 +439,9 @@ class PV_LogLikelihood(BaseFlowValidationModel): Parameters ---------- los_density : 3-dimensional array of shape (n_sims, n_objects, n_steps) - LOS density field. + LOS density field. Set to `None` if the data is void data. los_velocity : 3-dimensional array of shape (n_sims, n_objects, n_steps) - LOS radial velocity field. + LOS radial velocity field. Set to `None` if the data is void data. RA, dec : 1-dimensional arrays of shape (n_objects) Right ascension and declination in degrees. z_obs : 1-dimensional array of shape (n_objects) @@ -409,6 +462,8 @@ class PV_LogLikelihood(BaseFlowValidationModel): Catalogue kind, either "TFR", "SN", or "simple". name : str Name of the catalogue. + void_kwargs : dict, optional + Void data parameters. If `None` the data is not void data. with_num_dist_marginalisation : bool, optional Whether to use numerical distance marginalisation, in which case the tracers cannot be coupled by a covariance matrix. By default @@ -417,19 +472,31 @@ class PV_LogLikelihood(BaseFlowValidationModel): def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, calibration_params, abs_calibration_params, mag_selection, - r_xrange, Omega_m, kind, name, with_num_dist_marginalisation): + r_xrange, Omega_m, kind, name, void_kwargs=None, + with_num_dist_marginalisation=True): if e_zobs is not None: e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) else: e2_cz_obs = jnp.zeros_like(z_obs) + self.is_void_data = void_kwargs is not None + + # This must be done before we convert to radians. + if void_kwargs is not None: + self._set_void_data(RA=RA, dec=dec, **void_kwargs) + # Convert RA/dec to radians. RA, dec = np.deg2rad(RA), np.deg2rad(dec) - names = ["log_los_density", "los_velocity", "RA", "dec", "z_obs", - "e2_cz_obs"] - values = [jnp.log(los_density), los_velocity, RA, dec, z_obs, - e2_cz_obs] + names = ["RA", "dec", "z_obs", "e2_cz_obs"] + values = [RA, dec, z_obs, e2_cz_obs] + + if not self.is_void_data: + names += ["_log_los_density", "_los_velocity"] + values += [jnp.log(los_density), los_velocity] + + # Set the void data + self._setattr_as_jax(names, values) self._set_calibration_params(calibration_params) self._set_abs_calibration_params(abs_calibration_params) @@ -660,17 +727,24 @@ class PV_LogLikelihood(BaseFlowValidationModel): mu_xrange[None, :], mu[:, None], e2_mu[:, None], self.log_r2_xrange[None, :]) + if self.is_void_data: + rLG = field_calibration_params["rLG"] + log_los_density = self.log_los_density(rLG=rLG) + los_velocity = self.los_velocity(rLG=rLG) + else: + log_los_density = self.log_los_density() + los_velocity = self.los_velocity() + # Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange) alpha = distmod_params["alpha"] - log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density + log_ptilde = log_ptilde[None, ...] + alpha * log_los_density ptilde = jnp.exp(log_ptilde) - # Normalization of p(r). Shape: (nsims, ndata) pnorm = simpson(ptilde, x=self.r_xrange, axis=-1) # Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange) - vrad = field_calibration_params["beta"] * self.los_velocity + vrad = field_calibration_params["beta"] * los_velocity vrad += (Vext_rad[None, :, None] + Vmono) zobs = 1 + self.z_xrange[None, None, :] zobs *= 1 + vrad / SPEED_OF_LIGHT