Add mu sampling

This commit is contained in:
rstiskalek 2024-10-07 00:58:12 +01:00
parent 05123ec868
commit b99bf54789
2 changed files with 112 additions and 52 deletions

View file

@ -27,8 +27,10 @@ import numpy as np
from astropy import units as u
from astropy.coordinates import SkyCoord, angular_separation
from astropy.cosmology import FlatLambdaCDM, z_at_value
from interpax import interp1d
from jax import jit
from jax import numpy as jnp
from jax import vmap
from jax.scipy.special import erf, logsumexp
from numpyro import factor, plate, sample
from numpyro.distributions import MultivariateNormal, Normal, Uniform
@ -88,6 +90,39 @@ def distmod2dist(mu, Om0):
return (-0.000301) + (term1 * (term2 - (term3 - term4)))
def distmod2dist_gradient(mu, Om0):
    """
    Evaluate the derivative of the comoving distance in `Mpc / h` with
    respect to the distance modulus. The fitting expression is valid for a
    flat universe over the range of 0.00001 < z < 0.1.

    The derivative is obtained by applying the product rule to an expression
    of the form `prefactor * (linear - (power - slope))`.
    """
    # Exponential prefactor and its derivative.
    prefactor = jnp.exp((0.443288 * mu) + (-14.286531))
    dprefactor = 0.443288 * prefactor

    # Term linear in `mu` and its (constant) derivative.
    linear = (0.506973 * mu) + 12.954633
    dlinear = 0.506973

    # Power-law term `base**exponent`, where both depend on `mu`.
    growth = jnp.exp(0.072229 * mu)
    power = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - growth))  # noqa

    # d/dmu of base**exponent is
    #     base**exponent * (exponent / mu + exponent' * ln(base)),
    # with the two `mu` coefficients of the exponent merged into 0.835733.
    log_base = jnp.log(0.028134) + jnp.log(mu)
    expo = 0.835733 * mu + 1.235158 * Om0 - growth
    dexpo = 0.835733 - 0.072229 * growth
    dpower = power * ((1 / mu) * expo + dexpo * log_base)

    # Second term linear in `mu` and its (constant) derivative.
    slope = (-0.045160) * mu
    dslope = -0.045160

    bracket = linear - (power - slope)
    dbracket = dlinear - (dpower - dslope)
    return dprefactor * bracket + prefactor * dbracket
def distmod2redshift(mu, Om0):
    """
    Map a distance modulus to a redshift, assuming `h = 1`. The fitting
    expression is valid for a flat universe over the range of
    0.00001 < z < 0.1.
    """
    # The fit is expressed in terms of the logarithm of the redshift.
    power_term = (0.022347 * mu)**(12.631788 - ((-6.708757) * Om0))
    offset = (0.022187 * Om0) + (power_term + 19.529852)
    log_z = (0.461108 * mu) - offset
    return jnp.exp(log_z)
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
"""Project the external velocity vector onto the line of sight."""
cos_dec = jnp.cos(dec_radians)
@ -150,6 +185,37 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
return normal_logpdf(x, loc, scale) - jnp.log(norm)
###############################################################################
# LOS interpolation #
###############################################################################
def interpolate_los(r, los, rgrid, method="cubic"):
    """
    Interpolate the LOS field of every galaxy at that galaxy's own radial
    distance, for all simulations at once.

    Parameters
    ----------
    r : 1-dimensional array of shape `(n_gal, )`
        Radial distances at which to interpolate the LOS field.
    los : 3-dimensional array of shape `(n_sims, n_gal, n_steps)`
        LOS field.
    rgrid : 1-dimensional array
        Radial grid on which `los` is tabulated along its last axis; expected
        to have shape `(n_steps, )`.
    method : str, optional
        Interpolation method forwarded to `interp1d`. Default is `"cubic"`.

    Returns
    -------
    2-dimensional array of shape `(n_sims, n_gal)`
    """
    def interp_single(rn, los_row):
        # One galaxy in one simulation: a single query point against a single
        # radial profile.
        return interp1d(rn, rgrid, los_row, method=method)

    # Vectorize over the inner loop (n_gal) first, pairing each distance with
    # the matching LOS row, then over the outer loop (n_sims), where the same
    # distances are reused for every simulation.
    over_galaxies = vmap(interp_single, in_axes=(0, 0))
    over_simulations = vmap(over_galaxies, in_axes=(None, 0))
    return over_simulations(r, los)
###############################################################################
# Base flow validation #
###############################################################################
@ -291,6 +357,12 @@ class BaseFlowValidationModel(ABC):
return self._los_velocity
def log_los_density_at_r(self, r):
    """
    Interpolate the logarithm of the LOS density, tabulated on
    `self.r_xrange`, at the radial distances `r`.
    """
    log_density_grid = self.log_los_density()
    return interpolate_los(r, log_density_grid, self.r_xrange)
def los_velocity_at_r(self, r):
    """
    Interpolate the LOS velocity, tabulated on `self.r_xrange`, at the
    radial distances `r`.
    """
    velocity_grid = self.los_velocity()
    return interpolate_los(r, velocity_grid, self.r_xrange)
@abstractmethod
def __call__(self, **kwargs):
pass
@ -514,16 +586,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Name of the catalogue.
void_kwargs : dict, optional
Void data parameters. If `None` the data is not void data.
with_num_dist_marginalisation : bool, optional
Whether to use numerical distance marginalisation, in which case
the tracers cannot be coupled by a covariance matrix. By default
`True`.
wo_num_dist_marginalisation : bool, optional
Whether to directly sample the distance without numerical
marginalisation, in which case the tracers can be coupled by a
covariance matrix. By default `False`.
"""
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
calibration_params, abs_calibration_params, mag_selection,
r_xrange, Omega_m, kind, name, void_kwargs=None,
with_num_dist_marginalisation=True):
wo_num_dist_marginalisation=False):
if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else:
@ -549,7 +621,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
values += [jnp.log(los_density), los_velocity]
# Density required only if not numerically marginalising.
if not with_num_dist_marginalisation:
if not wo_num_dist_marginalisation:
names += ["_los_density"]
values += [los_density]
@ -561,12 +633,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
self.kind = kind
self.name = name
self.Omega_m = Omega_m
self.with_num_dist_marginalisation = with_num_dist_marginalisation
self.wo_num_dist_marginalisation = wo_num_dist_marginalisation
self.norm = - self.ndata * jnp.log(self.num_sims)
# TODO: Somewhere here prepare the interpolators in case of no
# numerical marginalisation.
if mag_selection is not None:
self.mag_selection_kind = mag_selection["kind"]
@ -772,25 +841,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
# 2. Log-likelihood of the true distance and observed redshifts.
# The marginalisation of the true distance can be done numerically.
# ----------------------------------------------------------------
if self.with_num_dist_marginalisation:
if not self.wo_num_dist_marginalisation:
if field_calibration_params["sample_h"]:
raise NotImplementedError("Sampling of h not implemented.")
# Rescale the grid to account for the sampled H0. For distance
# modulus going from Mpc / h to Mpc implies larger numerical
# values, so there has to be a minus sign since h < 1.
# mu_xrange = self.mu_xrange - 5 * jnp.log(h)
# The redshift should also be boosted since now the object are
# further away?
# Actually, the redshift ought to remain the same?
else:
mu_xrange = self.mu_xrange
raise NotImplementedError(
"Sampling of 'h' is not supported if numerically "
"marginalising the true distance.")
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
log_ptilde = log_ptilde_wo_bias(
mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.log_r2_xrange[None, :])
if self.is_void_data:
@ -832,56 +892,52 @@ class PV_LogLikelihood(BaseFlowValidationModel):
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
else:
if field_calibration_params["sample_h"]:
raise NotImplementedError("Sampling of h not implemented.")
raise NotImplementedError(
"Sampling of distance is not implemented. Work in progress.")
raise NotImplementedError(
"Sampling of h is not yet implemented.")
e_mu = jnp.sqrt(e2_mu)
# True distance modulus, shape is `(n_data)`.
with plate("plate_mu", self.ndata):
mu_true = sample("mu", Normal(mu, e_mu))
# True distance, shape is `(n_data)``
# True distance and redshift, shape is `(n_data)`.
r_true = distmod2dist(mu_true, self.Omega_m)
# TODO:
z_true = None
z_true = distmod2redshift(mu_true, self.Omega_m)
if self.is_void_data:
raise NotImplementedError(
"Void data not implemented yet for distance sampling.")
else:
# grid log(density), shape is `(n_sims, n_data, n_rad)`
log_los_density_grid = self.los_density()
# TODO: Need to add the interpolators for these
# Grid log(density), shape is `(n_sims, n_data, n_rad)`
log_los_density_grid = self.log_los_density()
# Densities and velocities at the true distances, shape is
# `(n_sims, n_data)`
log_density = None
los_velocity = None
log_density = self.log_los_density_at_r(r_true)
los_velocity = self.los_velocity_at_r(r_true)
alpha = distmod_params["alpha"]
# Check dimensions of all this
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
pnorm = (
self.log_r2_xrange[None, None, :]
+ self.log_r2_xrange[None, None, :]
+ alpha * log_los_density_grid
+ normal_logpdf(
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa
pnorm = jnp.exp(pnorm)
# Normalization of p(mu). Shape is now (nsims, ndata)
# Now integrate over the radial steps. Shape is `(nsims, ndata)`.
# No Jacobian here because I integrate over distance, not the
# distance modulus.
pnorm = simpson(pnorm, x=self.r_xrange, axis=-1)
# TODO: There should be a Jacobian?
# Jacobian |dr / dmu|_(mu_true), shape is `(n_data)`.
jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m))
# Calculate unnormalized log p(mu). Shape is (nsims, ndata)
ll = (
2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
+ jnp.log(jac)[None, :]
+ (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
+ alpha * log_density
+ normal_logpdf(mu_true, mu, e_mu)[None, :])
)
# Subtract the normalization. Shape remains (nsims, ndata)
ll -= jnp.log(pnorm)

View file

@ -88,7 +88,7 @@ def print_variables(names, variables):
def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
verbose=True):
wo_num_dist_marginalisation, verbose=True):
"""Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/"
@ -128,6 +128,7 @@ def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model(
loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation,
**get_model_kwargs)
fprint(f"num. radial steps is {len(loader.rdist)}")
@ -299,10 +300,10 @@ if __name__ == "__main__":
###########################################################################
# `None` means default behaviour
nsteps = 10_000
nburn = 2_000
nsteps = 1_000
nburn = 1_000
zcmb_min = None
zcmb_max = 0.05
zcmb_max = 0.04
nchains_harmonic = 10
num_epochs = 50
inference_method = "mike"
@ -313,8 +314,9 @@ if __name__ == "__main__":
sample_Vmag_vax = False
sample_Vmono = False
sample_mag_dipole = False
wo_num_dist_marginalisation = True
absolute_calibration = None
calculate_harmonic = False if inference_method == "bayes" else True
calculate_harmonic = (False if inference_method == "bayes" else True) and (not wo_num_dist_marginalisation) # noqa
sample_h = True if absolute_calibration is not None else False
fname_kwargs = {"inference_method": inference_method,
@ -341,6 +343,7 @@ if __name__ == "__main__":
"num_epochs": num_epochs,
"inference_method": inference_method,
"sample_mag_dipole": sample_mag_dipole,
"wo_dist_marg": wo_num_dist_marginalisation,
"absolute_calibration": absolute_calibration,
"sample_h": sample_h,
}
@ -420,7 +423,8 @@ if __name__ == "__main__":
print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa
fname_kwargs["nsim"] = ksim
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs)
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
wo_num_dist_marginalisation)
model_kwargs = {
"models": models,
"field_calibration_hyperparams": calibration_hyperparams,