mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-07-06 22:41:11 +00:00
Add mu sampling
This commit is contained in:
parent
05123ec868
commit
b99bf54789
2 changed files with 112 additions and 52 deletions
|
@ -27,8 +27,10 @@ import numpy as np
|
||||||
from astropy import units as u
|
from astropy import units as u
|
||||||
from astropy.coordinates import SkyCoord, angular_separation
|
from astropy.coordinates import SkyCoord, angular_separation
|
||||||
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
from astropy.cosmology import FlatLambdaCDM, z_at_value
|
||||||
|
from interpax import interp1d
|
||||||
from jax import jit
|
from jax import jit
|
||||||
from jax import numpy as jnp
|
from jax import numpy as jnp
|
||||||
|
from jax import vmap
|
||||||
from jax.scipy.special import erf, logsumexp
|
from jax.scipy.special import erf, logsumexp
|
||||||
from numpyro import factor, plate, sample
|
from numpyro import factor, plate, sample
|
||||||
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
from numpyro.distributions import MultivariateNormal, Normal, Uniform
|
||||||
|
@ -88,6 +90,39 @@ def distmod2dist(mu, Om0):
|
||||||
return (-0.000301) + (term1 * (term2 - (term3 - term4)))
|
return (-0.000301) + (term1 * (term2 - (term3 - term4)))
|
||||||
|
|
||||||
|
|
||||||
|
def distmod2dist_gradient(mu, Om0):
|
||||||
|
"""
|
||||||
|
Calculate the derivative of comoving distance in `Mpc / h` with respect to
|
||||||
|
the distance modulus. The expression is valid for a flat universe over the
|
||||||
|
range of 0.00001 < z < 0.1.
|
||||||
|
"""
|
||||||
|
term1 = jnp.exp((0.443288 * mu) + (-14.286531))
|
||||||
|
dterm1 = 0.443288 * term1
|
||||||
|
|
||||||
|
term2 = (0.506973 * mu) + 12.954633
|
||||||
|
dterm2 = 0.506973
|
||||||
|
|
||||||
|
term3 = ((0.028134 * mu)**(((0.684713 * mu) + ((0.151020 * mu) + (1.235158 * Om0))) - jnp.exp(0.072229 * mu))) # noqa
|
||||||
|
ln_base = jnp.log(0.028134) + jnp.log(mu)
|
||||||
|
exponent = 0.835733 * mu + 1.235158 * Om0 - jnp.exp(0.072229 * mu)
|
||||||
|
exponent_derivative = 0.835733 - 0.072229 * jnp.exp(0.072229 * mu)
|
||||||
|
dterm3 = term3 * ((1 / mu) * exponent + exponent_derivative * ln_base)
|
||||||
|
|
||||||
|
term4 = (-0.045160) * mu
|
||||||
|
dterm4 = -0.045160
|
||||||
|
|
||||||
|
return (dterm1 * (term2 - (term3 - term4))
|
||||||
|
+ term1 * (dterm2 - (dterm3 - dterm4)))
|
||||||
|
|
||||||
|
|
||||||
|
def distmod2redshift(mu, Om0):
|
||||||
|
"""
|
||||||
|
Convert distance modulus to redshift, assuming `h = 1`. The expression is
|
||||||
|
valid for a flat universe over the range of 0.00001 < z < 0.1.
|
||||||
|
"""
|
||||||
|
return jnp.exp(((0.461108 * mu) - ((0.022187 * Om0) + (((0.022347 * mu)** (12.631788 - ((-6.708757) * Om0))) + 19.529852)))) # noqa
|
||||||
|
|
||||||
|
|
||||||
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
|
def project_Vext(Vext_x, Vext_y, Vext_z, RA_radians, dec_radians):
|
||||||
"""Project the external velocity vector onto the line of sight."""
|
"""Project the external velocity vector onto the line of sight."""
|
||||||
cos_dec = jnp.cos(dec_radians)
|
cos_dec = jnp.cos(dec_radians)
|
||||||
|
@ -150,6 +185,37 @@ def upper_truncated_normal_logpdf(x, loc, scale, xmax):
|
||||||
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
return normal_logpdf(x, loc, scale) - jnp.log(norm)
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# LOS interpolation #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate_los(r, los, rgrid, method="cubic"):
|
||||||
|
"""
|
||||||
|
Interpolate the LOS field at a given radial distance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
r : 1-dimensional array of shape `(n_gal, )`
|
||||||
|
Radial distances at which to interpolate the LOS field.
|
||||||
|
los : 3-dimensional array of shape `(n_sims, n_gal, n_steps)`
|
||||||
|
LOS field.
|
||||||
|
rmin, rmax : float
|
||||||
|
Minimum and maximum radial distances in the data.
|
||||||
|
order : int, optional
|
||||||
|
The order of the interpolation. Default is 1, can be 0.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
2-dimensional array of shape `(n_sims, n_gal)`
|
||||||
|
"""
|
||||||
|
# Vectorize over the inner loop (ngal) first, then the outer loop (nsim)
|
||||||
|
def f(rn, los_row):
|
||||||
|
return interp1d(rn, rgrid, los_row, method=method)
|
||||||
|
|
||||||
|
return vmap(vmap(f, in_axes=(0, 0)), in_axes=(None, 0))(r, los)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Base flow validation #
|
# Base flow validation #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
@ -291,6 +357,12 @@ class BaseFlowValidationModel(ABC):
|
||||||
|
|
||||||
return self._los_velocity
|
return self._los_velocity
|
||||||
|
|
||||||
|
def log_los_density_at_r(self, r):
|
||||||
|
return interpolate_los(r, self.log_los_density(), self.r_xrange, )
|
||||||
|
|
||||||
|
def los_velocity_at_r(self, r):
|
||||||
|
return interpolate_los(r, self.los_velocity(), self.r_xrange, )
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
@ -514,16 +586,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
Name of the catalogue.
|
Name of the catalogue.
|
||||||
void_kwargs : dict, optional
|
void_kwargs : dict, optional
|
||||||
Void data parameters. If `None` the data is not void data.
|
Void data parameters. If `None` the data is not void data.
|
||||||
with_num_dist_marginalisation : bool, optional
|
wo_num_dist_marginalisation : bool, optional
|
||||||
Whether to use numerical distance marginalisation, in which case
|
Whether to directly sample the distance without numerical
|
||||||
the tracers cannot be coupled by a covariance matrix. By default
|
marginalisation. in which case the tracers can be coupled by a
|
||||||
`True`.
|
covariance matrix. By default `False`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
||||||
calibration_params, abs_calibration_params, mag_selection,
|
calibration_params, abs_calibration_params, mag_selection,
|
||||||
r_xrange, Omega_m, kind, name, void_kwargs=None,
|
r_xrange, Omega_m, kind, name, void_kwargs=None,
|
||||||
with_num_dist_marginalisation=True):
|
wo_num_dist_marginalisation=False):
|
||||||
if e_zobs is not None:
|
if e_zobs is not None:
|
||||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||||
else:
|
else:
|
||||||
|
@ -549,7 +621,7 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
values += [jnp.log(los_density), los_velocity]
|
values += [jnp.log(los_density), los_velocity]
|
||||||
|
|
||||||
# Density required only if not numerically marginalising.
|
# Density required only if not numerically marginalising.
|
||||||
if not with_num_dist_marginalisation:
|
if not wo_num_dist_marginalisation:
|
||||||
names += ["_los_density"]
|
names += ["_los_density"]
|
||||||
values += [los_density]
|
values += [los_density]
|
||||||
|
|
||||||
|
@ -561,12 +633,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
self.kind = kind
|
self.kind = kind
|
||||||
self.name = name
|
self.name = name
|
||||||
self.Omega_m = Omega_m
|
self.Omega_m = Omega_m
|
||||||
self.with_num_dist_marginalisation = with_num_dist_marginalisation
|
self.wo_num_dist_marginalisation = wo_num_dist_marginalisation
|
||||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||||
|
|
||||||
# TODO: Somewhere here prepare the interpolators in case of no
|
|
||||||
# numerical marginalisation.
|
|
||||||
|
|
||||||
if mag_selection is not None:
|
if mag_selection is not None:
|
||||||
self.mag_selection_kind = mag_selection["kind"]
|
self.mag_selection_kind = mag_selection["kind"]
|
||||||
|
|
||||||
|
@ -772,25 +841,16 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
# 2. Log-likelihood of the true distance and observed redshifts.
|
# 2. Log-likelihood of the true distance and observed redshifts.
|
||||||
# The marginalisation of the true distance can be done numerically.
|
# The marginalisation of the true distance can be done numerically.
|
||||||
# ----------------------------------------------------------------
|
# ----------------------------------------------------------------
|
||||||
if self.with_num_dist_marginalisation:
|
if not self.wo_num_dist_marginalisation:
|
||||||
|
|
||||||
if field_calibration_params["sample_h"]:
|
if field_calibration_params["sample_h"]:
|
||||||
raise NotImplementedError("Sampling of h not implemented.")
|
raise NotImplementedError(
|
||||||
# Rescale the grid to account for the sampled H0. For distance
|
"Sampling of 'h' is not supported if numerically "
|
||||||
# modulus going from Mpc / h to Mpc implies larger numerical
|
"marginalising the true distance.")
|
||||||
# values, so there has to be a minus sign since h < 1.
|
|
||||||
# mu_xrange = self.mu_xrange - 5 * jnp.log(h)
|
|
||||||
|
|
||||||
# The redshift should also be boosted since now the object are
|
|
||||||
# further away?
|
|
||||||
|
|
||||||
# Actually, the redshift ought to remain the same?
|
|
||||||
else:
|
|
||||||
mu_xrange = self.mu_xrange
|
|
||||||
|
|
||||||
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
|
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
|
||||||
log_ptilde = log_ptilde_wo_bias(
|
log_ptilde = log_ptilde_wo_bias(
|
||||||
mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
||||||
self.log_r2_xrange[None, :])
|
self.log_r2_xrange[None, :])
|
||||||
|
|
||||||
if self.is_void_data:
|
if self.is_void_data:
|
||||||
|
@ -832,56 +892,52 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||||
else:
|
else:
|
||||||
if field_calibration_params["sample_h"]:
|
if field_calibration_params["sample_h"]:
|
||||||
raise NotImplementedError("Sampling of h not implemented.")
|
raise NotImplementedError(
|
||||||
|
"Sampling of h is not yet implemented.")
|
||||||
raise NotImplementedError(
|
|
||||||
"Sampling of distance is not implemented. Work in progress.")
|
|
||||||
|
|
||||||
e_mu = jnp.sqrt(e2_mu)
|
e_mu = jnp.sqrt(e2_mu)
|
||||||
# True distance modulus, shape is `(n_data)``
|
# True distance modulus, shape is `(n_data)``
|
||||||
with plate("plate_mu", self.ndata):
|
with plate("plate_mu", self.ndata):
|
||||||
mu_true = sample("mu", Normal(mu, e_mu))
|
mu_true = sample("mu", Normal(mu, e_mu))
|
||||||
|
|
||||||
# True distance, shape is `(n_data)``
|
# True distance and redshift, shape is `(n_data)`.
|
||||||
r_true = distmod2dist(mu_true, self.Omega_m)
|
r_true = distmod2dist(mu_true, self.Omega_m)
|
||||||
# TODO:
|
z_true = distmod2redshift(mu_true, self.Omega_m)
|
||||||
z_true = None
|
|
||||||
|
|
||||||
if self.is_void_data:
|
if self.is_void_data:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Void data not implemented yet for distance sampling.")
|
"Void data not implemented yet for distance sampling.")
|
||||||
else:
|
else:
|
||||||
# grid log(density), shape is `(n_sims, n_data, n_rad)`
|
# Grid log(density), shape is `(n_sims, n_data, n_rad)`
|
||||||
log_los_density_grid = self.los_density()
|
log_los_density_grid = self.log_los_density()
|
||||||
|
|
||||||
# TODO: Need to add the interpolators for these
|
|
||||||
# Densities and velocities at the true distances, shape is
|
# Densities and velocities at the true distances, shape is
|
||||||
# `(n_sims, n_data)`
|
# `(n_sims, n_data)`
|
||||||
log_density = None
|
log_density = self.log_los_density_at_r(r_true)
|
||||||
los_velocity = None
|
los_velocity = self.los_velocity_at_r(r_true)
|
||||||
|
|
||||||
alpha = distmod_params["alpha"]
|
alpha = distmod_params["alpha"]
|
||||||
|
|
||||||
# Check dimensions of all this
|
|
||||||
|
|
||||||
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
|
# Normalisation of p(mu), shape is `(n_sims, n_data, n_rad)`
|
||||||
pnorm = (
|
pnorm = (
|
||||||
self.log_r2_xrange[None, None, :]
|
+ self.log_r2_xrange[None, None, :]
|
||||||
+ alpha * log_los_density_grid
|
+ alpha * log_los_density_grid
|
||||||
+ normal_logpdf(
|
+ normal_logpdf(
|
||||||
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa
|
self.mu_xrange[None, :], mu[:, None], e_mu[:, None])[None, ...]) # noqa
|
||||||
|
|
||||||
pnorm = jnp.exp(pnorm)
|
pnorm = jnp.exp(pnorm)
|
||||||
|
# Now integrate over the radial steps. Shape is `(nsims, ndata)`.
|
||||||
# Normalization of p(mu). Shape is now (nsims, ndata)
|
# No Jacobian here because I integrate over distance, not the
|
||||||
|
# distance modulus.
|
||||||
pnorm = simpson(pnorm, x=self.r_xrange, axis=-1)
|
pnorm = simpson(pnorm, x=self.r_xrange, axis=-1)
|
||||||
|
|
||||||
# TODO: There should be a Jacobian?
|
# Jacobian |dr / dmu|_(mu_true), shape is `(n_data)`.
|
||||||
|
jac = jnp.abs(distmod2dist_gradient(mu_true, self.Omega_m))
|
||||||
|
|
||||||
# Calculate unnormalized log p(mu). Shape is (nsims, ndata)
|
# Calculate unnormalized log p(mu). Shape is (nsims, ndata)
|
||||||
ll = (
|
ll = (
|
||||||
2 * (jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
|
+ jnp.log(jac)[None, :]
|
||||||
|
+ (2 * jnp.log(r_true) - self.log_r2_xrange_mean)[None, :]
|
||||||
+ alpha * log_density
|
+ alpha * log_density
|
||||||
+ normal_logpdf(mu_true, mu, e_mu)[None, :])
|
)
|
||||||
|
|
||||||
# Subtract the normalization. Shape remains (nsims, ndata)
|
# Subtract the normalization. Shape remains (nsims, ndata)
|
||||||
ll -= jnp.log(pnorm)
|
ll -= jnp.log(pnorm)
|
||||||
|
|
|
@ -88,7 +88,7 @@ def print_variables(names, variables):
|
||||||
|
|
||||||
|
|
||||||
def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
||||||
verbose=True):
|
wo_num_dist_marginalisation, verbose=True):
|
||||||
"""Load the data and create the NumPyro models."""
|
"""Load the data and create the NumPyro models."""
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
folder = "/mnt/extraspace/rstiskalek/catalogs/"
|
||||||
|
@ -128,6 +128,7 @@ def get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
||||||
ksmooth=ARGS.ksmooth)
|
ksmooth=ARGS.ksmooth)
|
||||||
models[i] = csiborgtools.flow.get_model(
|
models[i] = csiborgtools.flow.get_model(
|
||||||
loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
|
loader, mag_selection=mag_selection[i], void_kwargs=void_kwargs,
|
||||||
|
wo_num_dist_marginalisation=wo_num_dist_marginalisation,
|
||||||
**get_model_kwargs)
|
**get_model_kwargs)
|
||||||
|
|
||||||
fprint(f"num. radial steps is {len(loader.rdist)}")
|
fprint(f"num. radial steps is {len(loader.rdist)}")
|
||||||
|
@ -299,10 +300,10 @@ if __name__ == "__main__":
|
||||||
###########################################################################
|
###########################################################################
|
||||||
|
|
||||||
# `None` means default behaviour
|
# `None` means default behaviour
|
||||||
nsteps = 10_000
|
nsteps = 1_000
|
||||||
nburn = 2_000
|
nburn = 1_000
|
||||||
zcmb_min = None
|
zcmb_min = None
|
||||||
zcmb_max = 0.05
|
zcmb_max = 0.04
|
||||||
nchains_harmonic = 10
|
nchains_harmonic = 10
|
||||||
num_epochs = 50
|
num_epochs = 50
|
||||||
inference_method = "mike"
|
inference_method = "mike"
|
||||||
|
@ -313,8 +314,9 @@ if __name__ == "__main__":
|
||||||
sample_Vmag_vax = False
|
sample_Vmag_vax = False
|
||||||
sample_Vmono = False
|
sample_Vmono = False
|
||||||
sample_mag_dipole = False
|
sample_mag_dipole = False
|
||||||
|
wo_num_dist_marginalisation = True
|
||||||
absolute_calibration = None
|
absolute_calibration = None
|
||||||
calculate_harmonic = False if inference_method == "bayes" else True
|
calculate_harmonic = (False if inference_method == "bayes" else True) and (not wo_num_dist_marginalisation) # noqa
|
||||||
sample_h = True if absolute_calibration is not None else False
|
sample_h = True if absolute_calibration is not None else False
|
||||||
|
|
||||||
fname_kwargs = {"inference_method": inference_method,
|
fname_kwargs = {"inference_method": inference_method,
|
||||||
|
@ -341,6 +343,7 @@ if __name__ == "__main__":
|
||||||
"num_epochs": num_epochs,
|
"num_epochs": num_epochs,
|
||||||
"inference_method": inference_method,
|
"inference_method": inference_method,
|
||||||
"sample_mag_dipole": sample_mag_dipole,
|
"sample_mag_dipole": sample_mag_dipole,
|
||||||
|
"wo_dist_marg": wo_num_dist_marginalisation,
|
||||||
"absolute_calibration": absolute_calibration,
|
"absolute_calibration": absolute_calibration,
|
||||||
"sample_h": sample_h,
|
"sample_h": sample_h,
|
||||||
}
|
}
|
||||||
|
@ -420,7 +423,8 @@ if __name__ == "__main__":
|
||||||
print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa
|
print(f"{'Current simulation:':<20} {i + 1} ({ksim}) out of {len(ksim_iterator)}.") # noqa
|
||||||
|
|
||||||
fname_kwargs["nsim"] = ksim
|
fname_kwargs["nsim"] = ksim
|
||||||
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs)
|
models = get_models(ksim, get_model_kwargs, mag_selection, void_kwargs,
|
||||||
|
wo_num_dist_marginalisation)
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
"models": models,
|
"models": models,
|
||||||
"field_calibration_hyperparams": calibration_hyperparams,
|
"field_calibration_hyperparams": calibration_hyperparams,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue