Add mu sampling

This commit is contained in:
rstiskalek 2024-10-07 00:58:12 +01:00
parent 05123ec868
commit b99bf54789
2 changed files with 112 additions and 52 deletions

View file

@ -27,8 +27,10 @@ import numpy as np
from astropy import units as u from astropy import units as u
from astropy.coordinates import SkyCoord, angular_separation from astropy.coordinates import SkyCoord, angular_separation
from astropy.cosmology import FlatLambdaCDM, z_at_value from astropy.cosmology import FlatLambdaCDM, z_at_value
from interpax import interp1d
from jax import jit from jax import jit
from jax import numpy as jnp from jax import numpy as jnp
from jax import vmap
from jax.scipy.special import erf, logsumexp from jax.scipy.special import erf, logsumexp
from numpyro import factor, plate, sample from numpyro import factor, plate, sample
from numpyro.distributions import MultivariateNormal, Normal, Uniform from numpyro.distributions import MultivariateNormal, Normal, Uniform
@ -88,6 +90,39 @@ def distmod2dist(mu, Om0):
return (-0.000301) + (term1 * (term2 - (term3 - term4))) return (-0.000301) + (term1 * (term2 - (term3 - term4)))
def distmod2dist_gradient(mu, Om0):
"""
Calculate the derivative of comoving distance in `Mpc / h` with respect to
the distance modulus. The expression is valid for a flat universe over the
range of 0.00001 < z < 0.1.
"""
term1 = jnp.exp((0.443288 * mu) + (-14.286531))
dterm1 = 0.443288 * term1
term2 = (0.506973 * mu) + 12.954633
dterm2 = 0.506973
term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa
ln_base = jnp.log(0.028134) + jnp.log(mu)
exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu)
exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu)
dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base)
term4 = (-0.045160) * mu
dterm4 = -0.045160
return (dterm1 * (term2 - (term3 - term4))
+ term1 * (dterm2 - (dterm3 - dterm4)))
def distmod2redshift(mu, Om0):
"""
Convert distance modulus to redshift, assuming `h = 1`. The expression is
valid for a flat universe over the range of 0.00001 < z < 0.1.
"""
return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians): def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
"""Project the external velocity vector onto the line of sight.""" """Project the external velocity vector onto the line of sight."""
cos_dec = jnp.cos(dec_radians) cos_dec = jnp.cos(dec_radians)
@ -150,6 +185,37 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
return normal_logpdf(x, loc, scale) - jnp.log(norm) return normal_logpdf(x, loc, scale) - jnp.log(norm)
###############################################################################
# LOS interpolation #
###############################################################################
def interpolate_los(r, los, rgrid, method="cubic"):
"""
Interpolate the LOS field at a given radial distance.
Parameters
----------
r : 1-dimensional array of shape `(n_gal, )`
Radial distances at which to interpolate the LOS field.
los : 3-dimensional array of shape `(n_sims, n_gal, n_steps)`
LOS field.
rmin, rmax : float
Minimum and maximum radial distances in the data.
order : int, optional
The order of the interpolation. Default is 1, can be 0.
Returns
-------
2-dimensional array of shape `(n_sims, n_gal)`
"""
# Vectorize over the inner loop (ngal) first, then the outer loop (nsim)
def f(rn, los_row):
return interp1d(rn, rgrid, los_row, method=method)
return vmap(vmap(f, in_axes=(0, 0)), in_axes=(None, 0))(r, los)
############################################################################### ###############################################################################
# Base flow validation # # Base flow validation #
############################################################################### ###############################################################################
@ -291,6 +357,12 @@ class BaseFlowValidationModel(ABC):
return self._los_velocity return self._los_velocity
def log_los_density_at_r(self, r):
return interpolate_los(r, self.log_los_density(), self.r_xrange, )
def los_velocity_at_r(self, r):
return interpolate_los(r, self.los_velocity(), self.r_xrange, )
@abstractmethod @abstractmethod
def __call__(self, **kwargs): def __call__(self, **kwargs):
pass pass
@ -514,16 +586,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
Name of the catalogue. Name of the catalogue.
void_kwargs : dict, optional void_kwargs : dict, optional
Void data parameters. If `None` the data is not void data. Void data parameters. If `None` the data is not void data.
with_num_dist_marginalisation : bool, optional wo_num_dist_marginalisation : bool, optional
Whether to use numerical distance marginalisation, in which case Whether to directly sample the distance without numerical
the tracers cannot be coupled by a covariance matrix. By default marginalisation. in which case the tracers can be coupled by a
`True`. covariance matrix. By default `False`.
""" """
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs, def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
calibration_params, abs_calibration_params, mag_selection, calibration_params, abs_calibration_params, mag_selection,
r_xrange, Omega_m, kind, name, void_kwargs=None, r_xrange, Omega_m, kind, name, void_kwargs=None,
with_num_dist_marginalisation=True): wo_num_dist_marginalisation=False):
if e_zobs is not None: if e_zobs is not None:
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2) e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
else: else:
@ -549,7 +621,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
values += [jnp.log(los_density), los_velocity] values += [jnp.log(los_density), los_velocity]
# Density required only if not numerically marginalising. # Density required only if not numerically marginalising.
if not with_num_dist_marginalisation: if not wo_num_dist_marginalisation:
names += ["_los_density"] names += ["_los_density"]
values += [los_density] values += [los_density]
@ -561,12 +633,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
self.kind = kind self.kind = kind
self.name = name self.name = name
self.Omega_m = Omega_m self.Omega_m = Omega_m
self.with_num_dist_marginalisation = with_num_dist_marginalisation self.wo_num_dist_marginalisation = wo_num_dist_marginalisation
self.norm = - self.ndata * jnp.log(self.num_sims) self.norm = - self.ndata * jnp.log(self.num_sims)
# TODO: Somewhere here prepare the interpolators in case of no
# numerical marginalisation.
if mag_selection is not None: if mag_selection is not None:
self.mag_selection_kind = mag_selection["kind"] self.mag_selection_kind = mag_selection["kind"]
@ -772,25 +841,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
# 2. Log-likelihood of the true distance and observed redshifts. # 2. Log-likelihood of the true distance and observed redshifts.
# The marginalisation of the true distance can be done numerically. # The marginalisation of the true distance can be done numerically.
# ---------------------------------------------------------------- # ----------------------------------------------------------------
if self.with_num_dist_marginalisation: if not self.wo_num_dist_marginalisation:
if field_calibration_params["sample_h"]: if field_calibration_params["sample_h"]:
raise NotImplementedError("Sampling of h not implemented.") raise NotImplementedError(
# Rescale the grid to account for the sampled H0. For distance "Sampling of 'h' is not supported if numerically "
# modulus going from Mpc / h to Mpc implies larger numerical "marginalising the true distance.")
# values, so there has to be a minus sign since h < 1.
# mu_xrange = self.mu_xrange - 5 * jnp.log(h)
# The redshift should also be boosted since now the object are
# further away?
# Actually, the redshift ought to remain the same?
else:
mu_xrange = self.mu_xrange
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange) # Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
log_ptilde = log_ptilde_wo_bias( log_ptilde = log_ptilde_wo_bias(
mu_xrange[None, :], mu[:, None], e2_mu[:, None], self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
self.log_r2_xrange[None, :]) self.log_r2_xrange[None, :])
if self.is_void_data: if self.is_void_data:
@ -832,56 +892,52 @@ class PV_LogLikelihood(BaseFlowValidationModel):
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
else: else:
if field_calibration_params["sample_h"]: if field_calibration_params["sample_h"]:
raise NotImplementedError("Sampling of h not implemented.") raise NotImplementedError(
"Sampling of h is not yet implemented.")
raise NotImplementedError(
"Sampling of distance is not implemented. Work in progress.")
e_mu = jnp.sqrt(e2_mu) e_mu = jnp.sqrt(e2_mu)
# True distance modulus, shape is `(n_data)`` # True distance modulus, shape is `(n_data)``
with plate("plate_mu", self.ndata): with plate("plate_mu", self.ndata):
mu_true = sample("mu", Normal(mu, e_mu)) mu_true = sample("mu", Normal(mu, e_mu))
# True distance, shape is `(n_data)`` # True distance and redshift, shape is `(n_data)`.
r_true = distmod2dist(mu_true, self.Omega_m) r_true = distmod2dist(mu_true, self.Omega_m)
# TODO: z_true = distmod2redshift(mu_true, self.Omega_m)
z_true = None
if self.is_void_data: if self.is_void_data:
raise NotImplementedError( raise NotImplementedError(
"Void data not implemented yet for distance sampling.") "Void data not implemented yet for distance sampling.")
else: else:
# grid log(density), shape is `(n_sims, n_data, n_rad)` # Grid log(density), shape is `(n_sims, n_data, n_rad)`
log_los_density_grid = self.los_density() log_los_density_grid = self.log_los_density()
# TODO: Need to add the interpolators for these
# Densities and velocities at the true distances, shape is # Densities and velocities at the true distances, shape is
# `(n_sims, n_data)` # `(n_sims, n_data)`
log_density = None log_density = self.log_los_density_at_r(r_true)
los_velocity = None los_velocity = self.los_velocity_at_r(r_true)
alpha = distmod_params["alpha"] alpha = distmod_params["alpha"]
# Check dimensions of all this
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)` # Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
pnorm = ( pnorm = (
self.log_r2_xrange[None, None, :] + self.log_r2_xrange[None, None, :]
+ alpha * log_los_density_grid + alpha * log_los_density_grid
+ normal_logpdf( + normal_logpdf(
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa
pnorm = jnp.exp(pnorm) pnorm = jnp.exp(pnorm)
# Now integrate over the radial steps. Shape is `(nsims, ndata)`.
# Normalization of p(mu). Shape is now (nsims, ndata) # No Jacobian here because I integrate over distance, not the
# distance modulus.
pnorm = simpson(pnorm, x=self.r_xrange, axis=-1) pnorm = simpson(pnorm, x=self.r_xrange, axis=-1)
# TODO: There should be a Jacobian? # Jacobian |dr / dmu|_(mu_true), shape is `(n_data)`.
jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m))
# Calculate unnormalized log p(mu). Shape is (nsims, ndata) # Calculate unnormalized log p(mu). Shape is (nsims, ndata)
ll = ( ll = (
2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :] + jnp.log(jac)[None, :]
+ (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
+ alpha * log_density + alpha * log_density
+ normal_logpdf(mu_true, mu, e_mu)[None, :]) )
# Subtract the normalization. Shape remains (nsims, ndata) # Subtract the normalization. Shape remains (nsims, ndata)
ll -= jnp.log(pnorm) ll -= jnp.log(pnorm)

View file

@ -88,7 +88,7 @@ def print_variables(names, variables):
def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs, def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
verbose=True): wo_num_dist_marginalisation, verbose=True):
"""Load the data and create the NumPyro models.""" """Load the data and create the NumPyro models."""
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring) paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
folder = "/mnt/extraspace/rstiskalek/catalogs/" folder = "/mnt/extraspace/rstiskalek/catalogs/"
@ -128,6 +128,7 @@ def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
ksmooth=ARGS.ksmooth) ksmooth=ARGS.ksmooth)
models[i] = csiborgtools.flow.get_model( models[i] = csiborgtools.flow.get_model(
loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs, loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
wo_num_dist_marginalisation=wo_num_dist_marginalisation,
**get_model_kwargs) **get_model_kwargs)
fprint(f"num. radial steps is {len(loader.rdist)}") fprint(f"num. radial steps is {len(loader.rdist)}")
@ -299,10 +300,10 @@ if __name__ == "__main__":
########################################################################### ###########################################################################
# `None` means default behaviour # `None` means default behaviour
nsteps = 10_000 nsteps = 1_000
nburn = 2_000 nburn = 1_000
zcmb_min = None zcmb_min = None
zcmb_max = 0.05 zcmb_max = 0.04
nchains_harmonic = 10 nchains_harmonic = 10
num_epochs = 50 num_epochs = 50
inference_method = "mike" inference_method = "mike"
@ -313,8 +314,9 @@ if __name__ == "__main__":
sample_Vmag_vax = False sample_Vmag_vax = False
sample_Vmono = False sample_Vmono = False
sample_mag_dipole = False sample_mag_dipole = False
wo_num_dist_marginalisation = True
absolute_calibration = None absolute_calibration = None
calculate_harmonic = False if inference_method == "bayes" else True calculate_harmonic = (False if inference_method == "bayes" else True) and (not wo_num_dist_marginalisation) # noqa
sample_h = True if absolute_calibration is not None else False sample_h = True if absolute_calibration is not None else False
fname_kwargs = {"inference_method": inference_method, fname_kwargs = {"inference_method": inference_method,
@ -341,6 +343,7 @@ if __name__ == "__main__":
"num_epochs": num_epochs, "num_epochs": num_epochs,
"inference_method": inference_method, "inference_method": inference_method,
"sample_mag_dipole": sample_mag_dipole, "sample_mag_dipole": sample_mag_dipole,
"wo_dist_marg": wo_num_dist_marginalisation,
"absolute_calibration": absolute_calibration, "absolute_calibration": absolute_calibration,
"sample_h": sample_h, "sample_h": sample_h,
} }
@ -420,7 +423,8 @@ if __name__ == "__main__":
print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa
fname_kwargs["nsim"] = ksim fname_kwargs["nsim"] = ksim
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs) models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
wo_num_dist_marginalisation)
model_kwargs = { model_kwargs = {
"models": models, "models": models,
"field_calibration_hyperparams": calibration_hyperparams, "field_calibration_hyperparams": calibration_hyperparams,