mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 20:28:02 +00:00
Absmag (#149)
* Add spacing * Add various cleanups of the code * Add basic updates * Add
This commit is contained in:
parent
4fa0e04f6e
commit
5336c0296c
1 changed files with 75 additions and 60 deletions
|
@ -45,13 +45,13 @@ H0 = 100 # km / s / Mpc
|
||||||
# JAX Flow model #
|
# JAX Flow model #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
def dist2redshift(dist, Omega_m):
|
def dist2redshift(dist, Omega_m, h=1.):
|
||||||
"""
|
"""
|
||||||
Convert comoving distance to cosmological redshift if the Universe is
|
Convert comoving distance to cosmological redshift if the Universe is
|
||||||
flat and z << 1.
|
flat and z << 1.
|
||||||
"""
|
"""
|
||||||
eta = 3 * Omega_m / 2
|
eta = 3 * Omega_m / 2
|
||||||
return 1 / eta * (1 - (1 - 2 * H0 * dist / SPEED_OF_LIGHT * eta)**0.5)
|
return 1 / eta * (1 - (1 - 2 * 100 * h * dist / SPEED_OF_LIGHT * eta)**0.5)
|
||||||
|
|
||||||
|
|
||||||
def redshift2dist(z, Omega_m):
|
def redshift2dist(z, Omega_m):
|
||||||
|
@ -366,11 +366,6 @@ def sample_calibration(Vext_min, Vext_max, Vmono_min, Vmono_max, beta_min,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
|
||||||
# PV calibration model #
|
|
||||||
###############################################################################
|
|
||||||
|
|
||||||
|
|
||||||
def sample_gaussian_hyperprior(param, name, xmin, xmax):
|
def sample_gaussian_hyperprior(param, name, xmin, xmax):
|
||||||
"""Sample MNR Gaussian hyperprior mean and standard deviation."""
|
"""Sample MNR Gaussian hyperprior mean and standard deviation."""
|
||||||
mean = sample(f"{param}_mean_{name}", Uniform(xmin, xmax))
|
mean = sample(f"{param}_mean_{name}", Uniform(xmin, xmax))
|
||||||
|
@ -378,9 +373,15 @@ def sample_gaussian_hyperprior(param, name, xmin, xmax):
|
||||||
return mean, std
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# PV calibration model without absolute calibration #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
class PV_LogLikelihood(BaseFlowValidationModel):
|
class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
"""
|
"""
|
||||||
Peculiar velocity validation model log-likelihood.
|
Peculiar velocity validation model log-likelihood with numerical
|
||||||
|
integration of the true distances.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -408,11 +409,15 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
Catalogue kind, either "TFR", "SN", or "simple".
|
Catalogue kind, either "TFR", "SN", or "simple".
|
||||||
name : str
|
name : str
|
||||||
Name of the catalogue.
|
Name of the catalogue.
|
||||||
|
with_num_dist_marginalisation : bool, optional
|
||||||
|
Whether to use numerical distance marginalisation, in which case
|
||||||
|
the tracers cannot be coupled by a covariance matrix. By default
|
||||||
|
`True`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
def __init__(self, los_density, los_velocity, RA, dec, z_obs, e_zobs,
|
||||||
calibration_params, abs_calibration_params, mag_selection,
|
calibration_params, abs_calibration_params, mag_selection,
|
||||||
r_xrange, Omega_m, kind, name):
|
r_xrange, Omega_m, kind, name, with_num_dist_marginalisation):
|
||||||
if e_zobs is not None:
|
if e_zobs is not None:
|
||||||
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
e2_cz_obs = jnp.asarray((SPEED_OF_LIGHT * e_zobs)**2)
|
||||||
else:
|
else:
|
||||||
|
@ -433,8 +438,12 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
self.kind = kind
|
self.kind = kind
|
||||||
self.name = name
|
self.name = name
|
||||||
self.Omega_m = Omega_m
|
self.Omega_m = Omega_m
|
||||||
|
self.with_num_dist_marginalisation = with_num_dist_marginalisation
|
||||||
self.norm = - self.ndata * jnp.log(self.num_sims)
|
self.norm = - self.ndata * jnp.log(self.num_sims)
|
||||||
|
|
||||||
|
# TODO: Somewhere here prepare the interpolators in case of no
|
||||||
|
# numerical marginalisation.
|
||||||
|
|
||||||
if mag_selection is not None:
|
if mag_selection is not None:
|
||||||
self.mag_selection_kind = mag_selection["kind"]
|
self.mag_selection_kind = mag_selection["kind"]
|
||||||
|
|
||||||
|
@ -485,6 +494,9 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
Vmono = field_calibration_params["Vmono"]
|
Vmono = field_calibration_params["Vmono"]
|
||||||
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
Vext_rad = project_Vext(Vext[0], Vext[1], Vext[2], self.RA, self.dec)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# 1. Sample true observables and obtain the distance estimate
|
||||||
|
# ------------------------------------------------------------
|
||||||
e_mu = distmod_params["e_mu"]
|
e_mu = distmod_params["e_mu"]
|
||||||
if self.kind == "SN":
|
if self.kind == "SN":
|
||||||
mag_cal = distmod_params["mag_cal"]
|
mag_cal = distmod_params["mag_cal"]
|
||||||
|
@ -532,10 +544,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
|
|
||||||
mu = distmod_SN(
|
mu = distmod_SN(
|
||||||
mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal)
|
mag_true, x1_true, c_true, mag_cal, alpha_cal, beta_cal)
|
||||||
|
|
||||||
if field_calibration_params["sample_h"]:
|
|
||||||
raise NotImplementedError("H0 for SN not implemented.")
|
|
||||||
|
|
||||||
elif self.kind == "TFR":
|
elif self.kind == "TFR":
|
||||||
a = distmod_params["a"]
|
a = distmod_params["a"]
|
||||||
b = distmod_params["b"]
|
b = distmod_params["b"]
|
||||||
|
@ -605,11 +613,6 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
e2_mu = jnp.ones_like(mag_true) * e_mu**2
|
e2_mu = jnp.ones_like(mag_true) * e_mu**2
|
||||||
|
|
||||||
mu = distmod_TFR(mag_true, eta_true, a, b, c)
|
mu = distmod_TFR(mag_true, eta_true, a, b, c)
|
||||||
|
|
||||||
if field_calibration_params["sample_h"]:
|
|
||||||
raise NotImplementedError("H0 for TFR not implemented.")
|
|
||||||
# mu -= 5 * jnp.log10(field_calibration_params["h"])
|
|
||||||
|
|
||||||
elif self.kind == "simple":
|
elif self.kind == "simple":
|
||||||
dmu = distmod_params["dmu"]
|
dmu = distmod_params["dmu"]
|
||||||
|
|
||||||
|
@ -628,61 +631,73 @@ class PV_LogLikelihood(BaseFlowValidationModel):
|
||||||
e2_mu = jnp.ones_like(mag_true) * e_mu**2
|
e2_mu = jnp.ones_like(mag_true) * e_mu**2
|
||||||
|
|
||||||
mu = mu_true + dmu
|
mu = mu_true + dmu
|
||||||
|
|
||||||
if field_calibration_params["sample_h"]:
|
|
||||||
raise NotImplementedError("H0 for simple not implemented.")
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown kind: `{self.kind}`.")
|
raise ValueError(f"Unknown kind: `{self.kind}`.")
|
||||||
|
|
||||||
|
# h = field_calibration_params["h"]
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
# 2. Log-likelihood of the true distance and observed redshifts.
|
||||||
|
# The marginalisation of the true distance can be done numerically.
|
||||||
|
# ----------------------------------------------------------------
|
||||||
|
if self.with_num_dist_marginalisation:
|
||||||
|
|
||||||
|
if field_calibration_params["sample_h"]:
|
||||||
|
raise NotImplementedError("Sampling of h not implemented.")
|
||||||
|
# Rescale the grid to account for the sampled H0. For distance
|
||||||
|
# modulus going from Mpc / h to Mpc implies larger numerical
|
||||||
|
# values, so there has to be a minus sign since h < 1.
|
||||||
|
# mu_xrange = self.mu_xrange - 5 * jnp.log(h)
|
||||||
|
|
||||||
|
# The redshift should also be boosted since now the object are
|
||||||
|
# further away?
|
||||||
|
|
||||||
|
# Actually, the redshift ought to remain the same?
|
||||||
|
else:
|
||||||
|
mu_xrange = self.mu_xrange
|
||||||
|
|
||||||
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
|
# Calculate p(r) (Malmquist bias). Shape is (ndata, nxrange)
|
||||||
log_ptilde = log_ptilde_wo_bias(
|
log_ptilde = log_ptilde_wo_bias(
|
||||||
self.mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
mu_xrange[None, :], mu[:, None], e2_mu[:, None],
|
||||||
self.log_r2_xrange[None, :])
|
self.log_r2_xrange[None, :])
|
||||||
|
|
||||||
# Inhomogeneous Malmquist bias. Shape is (n_sims, ndata, nxrange)
|
# Inhomogeneous Malmquist bias. Shape: (nsims, ndata, nxrange)
|
||||||
alpha = distmod_params["alpha"]
|
alpha = distmod_params["alpha"]
|
||||||
log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density
|
log_ptilde = log_ptilde[None, ...] + alpha * self.log_los_density
|
||||||
|
|
||||||
ptilde = jnp.exp(log_ptilde)
|
ptilde = jnp.exp(log_ptilde)
|
||||||
|
|
||||||
# Normalization of p(r). Shape is (n_sims, ndata)
|
# Normalization of p(r). Shape: (nsims, ndata)
|
||||||
pnorm = simpson(ptilde, x=self.r_xrange, axis=-1)
|
pnorm = simpson(ptilde, x=self.r_xrange, axis=-1)
|
||||||
|
|
||||||
# Calculate z_obs at each distance. Shape is (n_sims, ndata, nxrange)
|
# Calculate z_obs at each distance. Shape: (nsims, ndata, nxrange)
|
||||||
vrad = field_calibration_params["beta"] * self.los_velocity
|
vrad = field_calibration_params["beta"] * self.los_velocity
|
||||||
vrad += (Vext_rad[None, :, None] + Vmono)
|
vrad += (Vext_rad[None, :, None] + Vmono)
|
||||||
zobs = (1 + self.z_xrange[None, None, :]) * (1 + vrad / SPEED_OF_LIGHT)
|
zobs = 1 + self.z_xrange[None, None, :]
|
||||||
|
zobs *= 1 + vrad / SPEED_OF_LIGHT
|
||||||
zobs -= 1.
|
zobs -= 1.
|
||||||
|
|
||||||
# Shape remains (n_sims, ndata, nxrange)
|
# Shape remains (nsims, ndata, nxrange)
|
||||||
ptilde *= likelihood_zobs(
|
ptilde *= likelihood_zobs(
|
||||||
self.z_obs[None, :, None], zobs, e2_cz[None, :, None])
|
self.z_obs[None, :, None], zobs, e2_cz[None, :, None])
|
||||||
|
|
||||||
if self.with_absolute_calibration:
|
if self.with_absolute_calibration:
|
||||||
raise NotImplementedError("Absolute calibration not implemented.")
|
raise NotImplementedError(
|
||||||
# Absolute calibration likelihood, the shape is now
|
"Absolute calibration not implemented for this model. "
|
||||||
# (ndata_with_calibration, ncalib, nxrange)
|
"Use `PV_LogLikelihood_NoDistMarg` instead.")
|
||||||
# ll_calibration = normal_logpdf(
|
|
||||||
# self.mu_xrange[None, None, :],
|
|
||||||
# self.calibration_distmod[..., None],
|
|
||||||
# self.calibration_edistmod[..., None])
|
|
||||||
|
|
||||||
# # Average the likelihood over the calibration points. The shape
|
# Integrate over the radial distance. Shape: (nsims, ndata)
|
||||||
# is
|
|
||||||
# # now (ndata, nxrange)
|
|
||||||
# ll_calibration = logsumexp(
|
|
||||||
# jnp.nan_to_num(ll_calibration, nan=-jnp.inf), axis=1)
|
|
||||||
# # This is the normalisation because we want the *average*.
|
|
||||||
# ll_calibration -= self.log_length_calibration[:, None]
|
|
||||||
|
|
||||||
# ptilde = ptilde.at[:, self.data_with_calibration, :].
|
|
||||||
# multiply(jnp.exp(ll_calibration))
|
|
||||||
|
|
||||||
# Integrate over the radial distance. Shape is (n_sims, ndata)
|
|
||||||
ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1))
|
ll = jnp.log(simpson(ptilde, x=self.r_xrange, axis=-1))
|
||||||
ll -= jnp.log(pnorm)
|
ll -= jnp.log(pnorm)
|
||||||
|
|
||||||
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
return ll0 + jnp.sum(logsumexp(ll, axis=0)) + self.norm
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"No distance marginalisation not implemented yet.")
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Combining several catalogues #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def PV_validation_model(models, distmod_hyperparams_per_model,
|
def PV_validation_model(models, distmod_hyperparams_per_model,
|
||||||
|
|
Loading…
Reference in a new issue