diff --git a/csiborgtools/flow/flow_model.py b/csiborgtools/flow/flow_model.py index d2443e6..7bb358d 100644 --- a/csiborgtools/flow/flow_model.py +++ b/csiborgtools/flow/flow_model.py @@ -25,7 +25,6 @@ from abc import ABC, abstractmethod import numpy as np from astropy import units as u -from astropy.coordinates import SkyCoord, angular_separation from astropy.cosmology import FlatLambdaCDM, z_at_value from interpax import interp1d from jax import jit @@ -39,90 +38,19 @@ from tqdm import trange from ..params import SPEED_OF_LIGHT from ..utils import fprint +from .cosmography import (dist2redshift, distmod2dist, distmod2dist_gradient, + distmod2redshift, gradient_redshift2dist) from .selection import toy_log_magnitude_selection -from .void_model import interpolate_void, load_void_data +from .void_model import (angular_distance_from_void_axis, interpolate_void, + load_void_data) H0 = 100 # km / s / Mpc ############################################################################### -# JAX Flow model # +# Various flow utilities # ############################################################################### -def dist2redshift(dist, Omega_m, h=1.): - """ - Convert comoving distance to cosmological redshift if the Universe is - flat and z << 1. - """ - eta = 3 * Omega_m / 2 - return 1 / eta * (1 - (1 - 2 * 100 * h * dist / SPEED_OF_LIGHT * eta)**0.5) - - -def redshift2dist(z, Omega_m): - """ - Convert cosmological redshift to comoving distance if the Universe is - flat and z << 1. - """ - q0 = 3 * Omega_m / 2 - 1 - return SPEED_OF_LIGHT * z / (2 * H0) * (2 - z * (1 + q0)) - - -def gradient_redshift2dist(z, Omega_m): - """ - Gradient of the redshift to comoving distance conversion if the Universe is - flat and z << 1. - """ - q0 = 3 * Omega_m / 2 - 1 - return SPEED_OF_LIGHT / H0 * (1 - z * (1 + q0)) - - -def distmod2dist(mu, Om0): - """ - Convert distance modulus to distance in `Mpc / h`. The expression is valid - for a flat universe over the range of 0.00001 < z < 0.1. - """ - term1 = jnp.exp((0.443288 * mu) + (-14.286531)) - term2 = (0.506973 * mu) + 12.954633 - term3 = ((0.028134 * mu) ** ( - ((0.684713 * mu) - + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) - term4 = (-0.045160) * mu - return (-0.000301) + (term1 * (term2 - (term3 - term4))) - - -def distmod2dist_gradient(mu, Om0): - """ - Calculate the derivative of comoving distance in `Mpc / h` with respect to - the distance modulus. The expression is valid for a flat universe over the - range of 0.00001 < z < 0.1. - """ - term1 = jnp.exp((0.443288 * mu) + (-14.286531)) - dterm1 = 0.443288 * term1 - - term2 = (0.506973 * mu) + 12.954633 - dterm2 = 0.506973 - - term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa - ln_base = jnp.log(0.028134) + jnp.log(mu) - exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu) - exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu) - dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base) - - term4 = (-0.045160) * mu - dterm4 = -0.045160 - - return (dterm1 * (term2 - (term3 - term4)) - + term1 * (dterm2 - (dterm3 - dterm4))) - - -def distmod2redshift(mu, Om0): - """ - Convert distance modulus to redshift, assuming `h = 1`. The expression is - valid for a flat universe over the range of 0.00001 < z < 0.1. - """ - return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa - - def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians): """Project the external velocity vector onto the line of sight.""" cos_dec = jnp.cos(dec_radians) @@ -298,17 +226,12 @@ class BaseFlowValidationModel(ABC): rLG_grid *= h rLG_min, rLG_max = rLG_grid.min(), rLG_grid.max() rgrid_min, rgrid_max = 0, 250 - fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc.") + fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc / h.") rgrid_max *= h - # Get angular separation (in degrees) of each object from the model - # axis. - model_axis = SkyCoord(l=117, b=4, frame='galactic', unit='deg').icrs - coords = SkyCoord(ra=RA, dec=dec, unit='deg').icrs - - phi = angular_separation(coords.ra.rad, coords.dec.rad, - model_axis.ra.rad, model_axis.dec.rad) - phi = jnp.asarray(phi * 180 / np.pi, dtype=jnp.float32) + # Get angular separation of each object from the model axis. + phi = angular_distance_from_void_axis(RA, dec) + phi = jnp.asarray(phi, dtype=jnp.float32) if kind == "density": void_grid = jnp.log(void_grid) @@ -836,7 +759,6 @@ class PV_LogLikelihood(BaseFlowValidationModel): else: raise ValueError(f"Unknown kind: `{self.kind}`.") - # h = field_calibration_params["h"] # ---------------------------------------------------------------- # 2. Log-likelihood of the true distance and observed redshifts. # The marginalisation of the true distance can be done numerically. @@ -989,7 +911,7 @@ def PV_validation_model(models, distmod_hyperparams_per_model, # We sample the components of Vext with a uniform prior, which means # there is a |Vext|^2 prior, we correct for this so that the sampling # is effecitvely uniformly in magnitude of Vext and angles. - if "Vext" in field_calibration_params: + if "Vext" in field_calibration_params and not field_calibration_hyperparams["no_Vext"]: # noqa ll -= jnp.log(jnp.sum(field_calibration_params["Vext"]**2)) for n in range(len(models)):