mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-07-06 14:31:11 +00:00
Add mu sampling
This commit is contained in:
parent
05123ec868
commit
b99bf54789
2 changed files with 112 additions and 52 deletions
|
@ -27,8 +27,10 @@ import numpy as np
|
|||
from astropy import units as u
|
||||
from astropy.coordinates import SkyCoord, angular_separation
|
||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||
from interpax import interp1d
|
||||
from jax import jit
|
||||
from jax import numpy as jnp
|
||||
from jax import vmap
|
||||
from jax.scipy.special import erf, logsumexp
|
||||
from numpyro import factor, plate, sample
|
||||
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
||||
|
@ -88,6 +90,39 @@ def distmod2dist(mu, Om0):
|
|||
return (-0.000301) + (term1 * (term2 - (term3 - term4)))
|
||||
|
||||
|
||||
def distmod2dist_gradient(mu, Om0):
|
||||
"""
|
||||
Calculate the derivative of comoving distance in `Mpc / h` with respect to
|
||||
the distance modulus. The expression is valid for a flat universe over the
|
||||
range of 0.00001 < z < 0.1.
|
||||
"""
|
||||
term1 = jnp.exp((0.443288 * mu) + (-14.286531))
|
||||
dterm1 = 0.443288 * term1
|
||||
|
||||
term2 = (0.506973 * mu) + 12.954633
|
||||
dterm2 = 0.506973
|
||||
|
||||
term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa
|
||||
ln_base = jnp.log(0.028134) + jnp.log(mu)
|
||||
exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu)
|
||||
exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu)
|
||||
dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base)
|
||||
|
||||
term4 = (-0.045160) * mu
|
||||
dterm4 = -0.045160
|
||||
|
||||
return (dterm1 * (term2 - (term3 - term4))
|
||||
+ term1 * (dterm2 - (dterm3 - dterm4)))
|
||||
|
||||
|
||||
def distmod2redshift(mu, Om0):
|
||||
"""
|
||||
Convert distance modulus to redshift, assuming `h = 1`. The expression is
|
||||
valid for a flat universe over the range of 0.00001 < z < 0.1.
|
||||
"""
|
||||
return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa
|
||||
|
||||
|
||||
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
|
||||
"""Project the external velocity vector onto the line of sight."""
|
||||
cos_dec = jnp.cos(dec_radians)
|
||||
|
@ -150,6 +185,37 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
|
|||
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# LOS interpolation #
|
||||
###############################################################################
|
||||
|
||||
|
||||
def interpolate_los(r, los, rgrid, method="cubic"):
|
||||
"""
|
||||
Interpolate the LOS field at a given radial distance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
r : 1-dimensional array of shape `(n_gal, )`
|
||||
Radial distances at which to interpolate the LOS field.
|
||||
los : 3-dimensional array of shape `(n_sims, n_gal, n_steps)`
|
||||
LOS field.
|
||||
rmin, rmax : float
|
||||
Minimum and maximum radial distances in the data.
|
||||
order : int, optional
|
||||
The order of the interpolation. Default is 1, can be 0.
|
||||
|
||||
Returns
|
||||
-------
|
||||
2-dimensional array of shape `(n_sims, n_gal)`
|
||||
"""
|
||||
# Vectorize over the inner loop (ngal) first, then the outer loop (nsim)
|
||||
def f(rn, los_row):
|
||||
return interp1d(rn, rgrid, los_row, method=method)
|
||||
|
||||
return vmap(vmap(f, in_axes=(0, 0)), in_axes=(None, 0))(r, los)
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Base flow validation #
|
||||
###############################################################################
|
||||
|
@ -291,6 +357,12 @@ class BaseFlowValidationModel(ABC):
|
|||
|
||||
return self._los_velocity
|
||||
|
||||
def log_los_density_at_r(self, r):
|
||||
return interpolate_los(r, self.log_los_density(), self.r_xrange, )
|
||||
|
||||
def los_velocity_at_r(self, r):
|
||||
return interpolate_los(r, self.los_velocity(), self.r_xrange, )
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, **kwargs):
|
||||
pass
|
||||
|
@ -514,16 +586,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
Name of the catalogue.
|
||||
void_kwargs : dict, optional
|
||||
Void data parameters. If `None` the data is not void data.
|
||||
with_num_dist_marginalisation : bool, optional
|
||||
Whether to use numerical distance marginalisation, in which case
|
||||
the tracers cannot be coupled by a covariance matrix. By default
|
||||
`True`.
|
||||
wo_num_dist_marginalisation : bool, optional
|
||||
Whether to directly sample the distance without numerical
|
||||
marginalisation. in which case the tracers can be coupled by a
|
||||
covariance matrix. By default `False`.
|
||||
"""
|
||||
|
||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
||||
calibration_params, abs_calibration_params, mag_selection,
|
||||
r_xrange, Omega_m, kind, name, void_kwargs=None,
|
||||
with_num_dist_marginalisation=True):
|
||||
wo_num_dist_marginalisation=False):
|
||||
if e_zobs is not None:
|
||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||
else:
|
||||
|
@ -549,7 +621,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
values += [jnp.log(los_density), los_velocity]
|
||||
|
||||
# Density required only if not numerically marginalising.
|
||||
if not with_num_dist_marginalisation:
|
||||
if not wo_num_dist_marginalisation:
|
||||
names += ["_los_density"]
|
||||
values += [los_density]
|
||||
|
||||
|
@ -561,12 +633,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
self.kind = kind
|
||||
self.name = name
|
||||
self.Omega_m = Omega_m
|
||||
self.with_num_dist_marginalisation = with_num_dist_marginalisation
|
||||
self.wo_num_dist_marginalisation = wo_num_dist_marginalisation
|
||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||
|
||||
# TODO: Somewhere here prepare the interpolators in case of no
|
||||
# numerical marginalisation.
|
||||
|
||||
if mag_selection is not None:
|
||||
self.mag_selection_kind = mag_selection["kind"]
|
||||
|
||||
|
@ -772,25 +841,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
# 2. Log-likelihood of the true distance and observed redshifts.
|
||||
# The marginalisation of the true distance can be done numerically.
|
||||
# ----------------------------------------------------------------
|
||||
if self.with_num_dist_marginalisation:
|
||||
if not self.wo_num_dist_marginalisation:
|
||||
|
||||
if field_calibration_params["sample_h"]:
|
||||
raise NotImplementedError("Sampling of h not implemented.")
|
||||
# Rescale the grid to account for the sampled H0. For distance
|
||||
# modulus going from Mpc / h to Mpc implies larger numerical
|
||||
# values, so there has to be a minus sign since h < 1.
|
||||
# mu_xrange = self.mu_xrange - 5 * jnp.log(h)
|
||||
|
||||
# The redshift should also be boosted since now the object are
|
||||
# further away?
|
||||
|
||||
# Actually, the redshift ought to remain the same?
|
||||
else:
|
||||
mu_xrange = self.mu_xrange
|
||||
raise NotImplementedError(
|
||||
"Sampling of 'h' is not supported if numerically "
|
||||
"marginalising the true distance.")
|
||||
|
||||
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
|
||||
log_ptilde = log_ptilde_wo_bias(
|
||||
mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
||||
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
||||
self.log_r2_xrange[None, :])
|
||||
|
||||
if self.is_void_data:
|
||||
|
@ -832,56 +892,52 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
|||
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||
else:
|
||||
if field_calibration_params["sample_h"]:
|
||||
raise NotImplementedError("Sampling of h not implemented.")
|
||||
|
||||
raise NotImplementedError(
|
||||
"Sampling of distance is not implemented. Work in progress.")
|
||||
raise NotImplementedError(
|
||||
"Sampling of h is not yet implemented.")
|
||||
|
||||
e_mu = jnp.sqrt(e2_mu)
|
||||
# True distance modulus, shape is `(n_data)``
|
||||
with plate("plate_mu", self.ndata):
|
||||
mu_true = sample("mu", Normal(mu, e_mu))
|
||||
|
||||
# True distance, shape is `(n_data)``
|
||||
# True distance and redshift, shape is `(n_data)`.
|
||||
r_true = distmod2dist(mu_true, self.Omega_m)
|
||||
# TODO:
|
||||
z_true = None
|
||||
z_true = distmod2redshift(mu_true, self.Omega_m)
|
||||
|
||||
if self.is_void_data:
|
||||
raise NotImplementedError(
|
||||
"Void data not implemented yet for distance sampling.")
|
||||
else:
|
||||
# grid log(density), shape is `(n_sims, n_data, n_rad)`
|
||||
log_los_density_grid = self.los_density()
|
||||
|
||||
# TODO: Need to add the interpolators for these
|
||||
# Grid log(density), shape is `(n_sims, n_data, n_rad)`
|
||||
log_los_density_grid = self.log_los_density()
|
||||
# Densities and velocities at the true distances, shape is
|
||||
# `(n_sims, n_data)`
|
||||
log_density = None
|
||||
los_velocity = None
|
||||
log_density = self.log_los_density_at_r(r_true)
|
||||
los_velocity = self.los_velocity_at_r(r_true)
|
||||
|
||||
alpha = distmod_params["alpha"]
|
||||
|
||||
# Check dimensions of all this
|
||||
|
||||
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
|
||||
pnorm = (
|
||||
self.log_r2_xrange[None, None, :]
|
||||
+ self.log_r2_xrange[None, None, :]
|
||||
+ alpha * log_los_density_grid
|
||||
+ normal_logpdf(
|
||||
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa
|
||||
|
||||
pnorm = jnp.exp(pnorm)
|
||||
|
||||
# Normalization of p(mu). Shape is now (nsims, ndata)
|
||||
# Now integrate over the radial steps. Shape is `(nsims, ndata)`.
|
||||
# No Jacobian here because I integrate over distance, not the
|
||||
# distance modulus.
|
||||
pnorm = simpson(pnorm, x=self.r_xrange, axis=-1)
|
||||
|
||||
# TODO: There should be a Jacobian?
|
||||
# Jacobian |dr / dmu|_(mu_true), shape is `(n_data)`.
|
||||
jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m))
|
||||
|
||||
# Calculate unnormalized log p(mu). Shape is (nsims, ndata)
|
||||
ll = (
|
||||
2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
|
||||
+ jnp.log(jac)[None, :]
|
||||
+ (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
|
||||
+ alpha * log_density
|
||||
+ normal_logpdf(mu_true, mu, e_mu)[None, :])
|
||||
)
|
||||
|
||||
# Subtract the normalization. Shape remains (nsims, ndata)
|
||||
ll -= jnp.log(pnorm)
|
||||
|
|
|
@ -88,7 +88,7 @@ def print_variables(names, variables):
|
|||
|
||||
|
||||
def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
||||
verbose=True):
|
||||
wo_num_dist_marginalisation, verbose=True):
|
||||
"""Load the data and create the NumPyro models."""
|
||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||
|
@ -128,6 +128,7 @@ def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
|||
ksmooth=ARGS.ksmooth)
|
||||
models[i] = csiborgtools.flow.get_model(
|
||||
loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
|
||||
wo_num_dist_marginalisation=wo_num_dist_marginalisation,
|
||||
**get_model_kwargs)
|
||||
|
||||
fprint(f"num. radial steps is {len(loader.rdist)}")
|
||||
|
@ -299,10 +300,10 @@ if __name__ == "__main__":
|
|||
###########################################################################
|
||||
|
||||
# `None` means default behaviour
|
||||
nsteps = 10_000
|
||||
nburn = 2_000
|
||||
nsteps = 1_000
|
||||
nburn = 1_000
|
||||
zcmb_min = None
|
||||
zcmb_max = 0.05
|
||||
zcmb_max = 0.04
|
||||
nchains_harmonic = 10
|
||||
num_epochs = 50
|
||||
inference_method = "mike"
|
||||
|
@ -313,8 +314,9 @@ if __name__ == "__main__":
|
|||
sample_Vmag_vax = False
|
||||
sample_Vmono = False
|
||||
sample_mag_dipole = False
|
||||
wo_num_dist_marginalisation = True
|
||||
absolute_calibration = None
|
||||
calculate_harmonic = False if inference_method == "bayes" else True
|
||||
calculate_harmonic = (False if inference_method == "bayes" else True) and (not wo_num_dist_marginalisation) # noqa
|
||||
sample_h = True if absolute_calibration is not None else False
|
||||
|
||||
fname_kwargs = {"inference_method": inference_method,
|
||||
|
@ -341,6 +343,7 @@ if __name__ == "__main__":
|
|||
"num_epochs": num_epochs,
|
||||
"inference_method": inference_method,
|
||||
"sample_mag_dipole": sample_mag_dipole,
|
||||
"wo_dist_marg": wo_num_dist_marginalisation,
|
||||
"absolute_calibration": absolute_calibration,
|
||||
"sample_h": sample_h,
|
||||
}
|
||||
|
@ -420,7 +423,8 @@ if __name__ == "__main__":
|
|||
print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa
|
||||
|
||||
fname_kwargs["nsim"] = ksim
|
||||
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs)
|
||||
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
||||
wo_num_dist_marginalisation)
|
||||
model_kwargs = {
|
||||
"models": models,
|
||||
"field_calibration_hyperparams": calibration_hyperparams,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue