mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2025-04-19 06:40:54 +00:00
Check Vext likelihoo
This commit is contained in:
parent
fa50e62fbe
commit
d7da107d1c
1 changed files with 10 additions and 88 deletions
|
@ -25,7 +25,6 @@ from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from astropy import units as u
|
from astropy import units as u
|
||||||
from astropy.coordinates import SkyCoord, angular_separation
|
|
||||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||||
from interpax import interp1d
|
from interpax import interp1d
|
||||||
from jax import jit
|
from jax import jit
|
||||||
|
@ -39,90 +38,19 @@ from tqdm import trange
|
||||||
|
|
||||||
from ..params import SPEED_OF_LIGHT
|
from ..params import SPEED_OF_LIGHT
|
||||||
from ..utils import fprint
|
from ..utils import fprint
|
||||||
|
from .cosmography import (dist2redshift, distmod2dist, distmod2dist_gradient,
|
||||||
|
distmod2redshift, gradient_redshift2dist)
|
||||||
from .selection import toy_log_magnitude_selection
|
from .selection import toy_log_magnitude_selection
|
||||||
from .void_model import interpolate_void, load_void_data
|
from .void_model import (angular_distance_from_void_axis, interpolate_void,
|
||||||
|
load_void_data)
|
||||||
|
|
||||||
H0 = 100 # km / s / Mpc
|
H0 = 100 # km / s / Mpc
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# JAX Flow model #
|
# Various flow utilities #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def dist2redshift(dist, Omega_m, h=1.):
|
|
||||||
"""
|
|
||||||
Convert comoving distance to cosmological redshift if the Universe is
|
|
||||||
flat and z << 1.
|
|
||||||
"""
|
|
||||||
eta = 3 * Omega_m / 2
|
|
||||||
return 1 / eta * (1 - (1 - 2 * 100 * h * dist / SPEED_OF_LIGHT * eta)**0.5)
|
|
||||||
|
|
||||||
|
|
||||||
def redshift2dist(z, Omega_m):
|
|
||||||
"""
|
|
||||||
Convert cosmological redshift to comoving distance if the Universe is
|
|
||||||
flat and z << 1.
|
|
||||||
"""
|
|
||||||
q0 = 3 * Omega_m / 2 - 1
|
|
||||||
return SPEED_OF_LIGHT * z / (2 * H0) * (2 - z * (1 + q0))
|
|
||||||
|
|
||||||
|
|
||||||
def gradient_redshift2dist(z, Omega_m):
|
|
||||||
"""
|
|
||||||
Gradient of the redshift to comoving distance conversion if the Universe is
|
|
||||||
flat and z << 1.
|
|
||||||
"""
|
|
||||||
q0 = 3 * Omega_m / 2 - 1
|
|
||||||
return SPEED_OF_LIGHT / H0 * (1 - z * (1 + q0))
|
|
||||||
|
|
||||||
|
|
||||||
def distmod2dist(mu, Om0):
|
|
||||||
"""
|
|
||||||
Convert distance modulus to distance in `Mpc / h`. The expression is valid
|
|
||||||
for a flat universe over the range of 0.00001 < z < 0.1.
|
|
||||||
"""
|
|
||||||
term1 = jnp.exp((0.443288 * mu) + (-14.286531))
|
|
||||||
term2 = (0.506973 * mu) + 12.954633
|
|
||||||
term3 = ((0.028134 * mu) ** (
|
|
||||||
((0.684713 * mu)
|
|
||||||
+ ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu)))
|
|
||||||
term4 = (-0.045160) * mu
|
|
||||||
return (-0.000301) + (term1 * (term2 - (term3 - term4)))
|
|
||||||
|
|
||||||
|
|
||||||
def distmod2dist_gradient(mu, Om0):
|
|
||||||
"""
|
|
||||||
Calculate the derivative of comoving distance in `Mpc / h` with respect to
|
|
||||||
the distance modulus. The expression is valid for a flat universe over the
|
|
||||||
range of 0.00001 < z < 0.1.
|
|
||||||
"""
|
|
||||||
term1 = jnp.exp((0.443288 * mu) + (-14.286531))
|
|
||||||
dterm1 = 0.443288 * term1
|
|
||||||
|
|
||||||
term2 = (0.506973 * mu) + 12.954633
|
|
||||||
dterm2 = 0.506973
|
|
||||||
|
|
||||||
term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa
|
|
||||||
ln_base = jnp.log(0.028134) + jnp.log(mu)
|
|
||||||
exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu)
|
|
||||||
exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu)
|
|
||||||
dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base)
|
|
||||||
|
|
||||||
term4 = (-0.045160) * mu
|
|
||||||
dterm4 = -0.045160
|
|
||||||
|
|
||||||
return (dterm1 * (term2 - (term3 - term4))
|
|
||||||
+ term1 * (dterm2 - (dterm3 - dterm4)))
|
|
||||||
|
|
||||||
|
|
||||||
def distmod2redshift(mu, Om0):
|
|
||||||
"""
|
|
||||||
Convert distance modulus to redshift, assuming `h = 1`. The expression is
|
|
||||||
valid for a flat universe over the range of 0.00001 < z < 0.1.
|
|
||||||
"""
|
|
||||||
return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
|
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
|
||||||
"""Project the external velocity vector onto the line of sight."""
|
"""Project the external velocity vector onto the line of sight."""
|
||||||
cos_dec = jnp.cos(dec_radians)
|
cos_dec = jnp.cos(dec_radians)
|
||||||
|
@ -298,17 +226,12 @@ class BaseFlowValidationModel(ABC):
|
||||||
rLG_grid *= h
|
rLG_grid *= h
|
||||||
rLG_min, rLG_max = rLG_grid.min(), rLG_grid.max()
|
rLG_min, rLG_max = rLG_grid.min(), rLG_grid.max()
|
||||||
rgrid_min, rgrid_max = 0, 250
|
rgrid_min, rgrid_max = 0, 250
|
||||||
fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc.")
|
fprint(f"setting radial grid from {rLG_min} to {rLG_max} Mpc / h.")
|
||||||
rgrid_max *= h
|
rgrid_max *= h
|
||||||
|
|
||||||
# Get angular separation (in degrees) of each object from the model
|
# Get angular separation of each object from the model axis.
|
||||||
# axis.
|
phi = angular_distance_from_void_axis(RA, dec)
|
||||||
model_axis = SkyCoord(l=117, b=4, frame='galactic', unit='deg').icrs
|
phi = jnp.asarray(phi, dtype=jnp.float32)
|
||||||
coords = SkyCoord(ra=RA, dec=dec, unit='deg').icrs
|
|
||||||
|
|
||||||
phi = angular_separation(coords.ra.rad, coords.dec.rad,
|
|
||||||
model_axis.ra.rad, model_axis.dec.rad)
|
|
||||||
phi = jnp.asarray(phi * 180 / np.pi, dtype=jnp.float32)
|
|
||||||
|
|
||||||
if kind == "density":
|
if kind == "density":
|
||||||
void_grid = jnp.log(void_grid)
|
void_grid = jnp.log(void_grid)
|
||||||
|
@ -836,7 +759,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown kind: `{self.kind}`.")
|
raise ValueError(f"Unknown kind: `{self.kind}`.")
|
||||||
|
|
||||||
# h = field_calibration_params["h"]
|
|
||||||
# ----------------------------------------------------------------
|
# ----------------------------------------------------------------
|
||||||
# 2. Log-likelihood of the true distance and observed redshifts.
|
# 2. Log-likelihood of the true distance and observed redshifts.
|
||||||
# The marginalisation of the true distance can be done numerically.
|
# The marginalisation of the true distance can be done numerically.
|
||||||
|
@ -989,7 +911,7 @@ def PV_validation_model(models, distmod_hyperparams_per_model,
|
||||||
# We sample the components of Vext with a uniform prior, which means
|
# We sample the components of Vext with a uniform prior, which means
|
||||||
# there is a |Vext|^2 prior, we correct for this so that the sampling
|
# there is a |Vext|^2 prior, we correct for this so that the sampling
|
||||||
# is effecitvely uniformly in magnitude of Vext and angles.
|
# is effecitvely uniformly in magnitude of Vext and angles.
|
||||||
if "Vext" in field_calibration_params:
|
if "Vext" in field_calibration_params and not field_calibration_hyperparams["no_Vext"]: # noqa
|
||||||
ll -= jnp.log(jnp.sum(field_calibration_params["Vext"]**2))
|
ll -= jnp.log(jnp.sum(field_calibration_params["Vext"]**2))
|
||||||
|
|
||||||
for n in range(len(models)):
|
for n in range(len(models)):
|
||||||
|
|
Loading…
Add table
Reference in a new issue