Add void support

This commit is contained in:
rstiskalek 2024-09-21 17:04:55 +01:00
parent 82b71922cc
commit f54eb34fd2

View file

@ -25,6 +25,7 @@ from abc import ABC, abstractmethod
import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord, angular_separation
from astropy.cosmology import FlatLambdaCDM, z_at_value
from jax import jit
from jax import numpy as jnp
@ -37,6 +38,7 @@ from tqdm import trange
from ..params import SPEED_OF_LIGHT
from ..utils import fprint
from .selection import toy_log_magnitude_selection
from .void_model import interpolate_void, load_void_data
H0 = 100 # km / s / Mpc
@ -193,6 +195,33 @@ class BaseFlowValidationModel(ABC):
self.z_xrange = jnp.asarray(z_xrange)
self.mu_xrange = jnp.asarray(mu_xrange)
def _set_void_data(self, RA, dec, kind, h, order):
"""Create the void interpolator."""
# h is the MOND model value of local H0 to convert the radial grid to
# Mpc / h
rLG_grid, void_grid = load_void_data(kind)
void_grid = jnp.asarray(void_grid, dtype=jnp.float32)
rLG_grid = jnp.asarray(rLG_grid, dtype=jnp.float32)
rLG_grid *= h
rLG_min, rLG_max = rLG_grid.min(), rLG_grid.max()
rgrid_min, rgrid_max = 0, 250
fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc.")
rgrid_max *= h
# Get angular separation (in degrees) of each object from the model
# axis.
model_axis = SkyCoord(l=117, b=4, frame='galactic', unit='deg').icrs
coords = SkyCoord(ra=RA, dec=dec, unit='deg').icrs
phi = angular_separation(coords.ra.rad, coords.dec.rad,
model_axis.ra.rad, model_axis.dec.rad)
phi = jnp.asarray(phi * 180 / np.pi, dtype=jnp.float32)
self.void_interpolator = lambda rLG: interpolate_void(
rLG, self.r_xrange, phi, void_grid, rgrid_min, rgrid_max,
rLG_min, rLG_max, order)
@property
def ndata(self):
"""Number of PV objects in the catalogue."""
@ -201,7 +230,24 @@ class BaseFlowValidationModel(ABC):
@property
def num_sims(self):
"""Number of simulations."""
return len(self.log_los_density)
if self.is_void_data:
return 1.
return len(self.log_los_density())
def log_los_density(self, **kwargs):
if self.is_void_data:
# Currently we have no densities for the void.
return jnp.zeros((1, self.ndata, len(self.r_xrange)))
return self._log_los_density
def los_velocity(self, **kwargs):
if self.is_void_data:
# We want the shape to be `(1, n_objects, n_radial_steps)``.
return self.void_interpolator(kwargs["rLG"])[None, ...]
return self._los_velocity
@abstractmethod
def __call__(self, **kwargs):
@ -331,7 +377,8 @@ def sample_simple(e_mu_min, e_mu_max, dmu_min, dmu_max, alpha_min, alpha_max,
def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
beta_max, sigma_v_min, sigma_v_max, h_min, h_max,
no_Vext, sample_Vmono, sample_beta, sample_h):
rLG_min, rLG_max, no_Vext, sample_Vmono, sample_beta,
sample_h, sample_rLG):
"""Sample the flow calibration."""
sigma_v = sample("sigma_v", Uniform(sigma_v_min, sigma_v_max))
@ -357,12 +404,18 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
else:
h = 1.0
if sample_rLG:
rLG = sample("rLG", Uniform(rLG_min, rLG_max))
else:
rLG = None
return {"Vext": Vext,
"Vmono": Vmono,
"sigma_v": sigma_v,
"beta": beta,
"h": h,
"sample_h": sample_h,
"rLG": rLG,
}
@ -386,9 +439,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Parameters
----------
los_density : 3-dimensional array of shape (n_sims, n_objects, n_steps)
LOS density field.
LOS density field. Set to `None` if the data is void data.
los_velocity : 3-dimensional array of shape (n_sims, n_objects, n_steps)
LOS radial velocity field.
LOS radial velocity field. Set to `None` if the data is void data.
RA, dec : 1-dimensional arrays of shape (n_objects)
Right ascension and declination in degrees.
z_obs : 1-dimensional array of shape (n_objects)
@ -409,6 +462,8 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Catalogue kind, either "TFR", "SN", or "simple".
name : str
Name of the catalogue.
void_kwargs : dict, optional
Void data parameters. If `None` the data is not void data.
with_num_dist_marginalisation : bool, optional
Whether to use numerical distance marginalisation, in which case
the tracers cannot be coupled by a covariance matrix. By default
@ -417,19 +472,31 @@ class PV_LogLikelihood(BaseFlowValidationModel):
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
calibration_params, abs_calibration_params, mag_selection,
r_xrange, Omega_m, kind, name, with_num_dist_marginalisation):
r_xrange, Omega_m, kind, name, void_kwargs=None,
with_num_dist_marginalisation=True):
if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else:
e2_cz_obs = jnp.zeros_like(z_obs)
self.is_void_data = void_kwargs is not None
# This must be done before we convert to radians.
if void_kwargs is not None:
self._set_void_data(RA=RA, dec=dec, **void_kwargs)
# Convert RA/dec to radians.
RA, dec = np.deg2rad(RA), np.deg2rad(dec)
names = ["log_los_density", "los_velocity", "RA", "dec", "z_obs",
"e2_cz_obs"]
values = [jnp.log(los_density), los_velocity, RA, dec, z_obs,
e2_cz_obs]
names = ["RA", "dec", "z_obs", "e2_cz_obs"]
values = [RA, dec, z_obs, e2_cz_obs]
if not self.is_void_data:
names += ["_log_los_density", "_los_velocity"]
values += [jnp.log(los_density), los_velocity]
# Set the void data
self._setattr_as_jax(names, values)
self._set_calibration_params(calibration_params)
self._set_abs_calibration_params(abs_calibration_params)
@ -660,17 +727,24 @@ class PV_LogLikelihood(BaseFlowValidationModel):
mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.log_r2_xrange[None, :])
if self.is_void_data:
rLG = field_calibration_params["rLG"]
log_los_density = self.log_los_density(rLG=rLG)
los_velocity = self.los_velocity(rLG=rLG)
else:
log_los_density = self.log_los_density()
los_velocity = self.los_velocity()
# Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange)
alpha = distmod_params["alpha"]
log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density
log_ptilde = log_ptilde[None, ...] + alpha * log_los_density
ptilde = jnp.exp(log_ptilde)
# Normalization of p(r). Shape: (nsims, ndata)
pnorm = simpson(ptilde, x=self.r_xrange, axis=-1)
# Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange)
vrad = field_calibration_params["beta"] * self.los_velocity
vrad = field_calibration_params["beta"] * los_velocity
vrad += (Vext_rad[None, :, None] + Vmono)
zobs = 1 + self.z_xrange[None, None, :]
zobs *= 1 + vrad / SPEED_OF_LIGHT