mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 22:38:03 +00:00
Add SN calibration model
This commit is contained in:
parent
b503a6f003
commit
088b15429b
1 changed files with 164 additions and 72 deletions
|
@ -538,13 +538,15 @@ class SD_PV_validation_model:
|
||||||
|
|
||||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs,
|
def __init__(self, los_density, los_velocity, RA, dec, z_obs,
|
||||||
r_hMpc, e_r_hMpc, r_xrange, Omega_m):
|
r_hMpc, e_r_hMpc, r_xrange, Omega_m):
|
||||||
# Convert everything to JAX arrays.
|
|
||||||
dt = jnp.float32
|
dt = jnp.float32
|
||||||
|
# Convert everything to JAX arrays.
|
||||||
self._los_density = jnp.asarray(los_density, dtype=dt)
|
self._los_density = jnp.asarray(los_density, dtype=dt)
|
||||||
self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
|
self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
|
||||||
|
|
||||||
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
|
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
|
||||||
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
|
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
|
||||||
self._z_obs = jnp.asarray(z_obs, dtype=dt)
|
self._z_obs = jnp.asarray(z_obs, dtype=dt)
|
||||||
|
|
||||||
self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt)
|
self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt)
|
||||||
self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt)
|
self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt)
|
||||||
|
|
||||||
|
@ -564,37 +566,39 @@ class SD_PV_validation_model:
|
||||||
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
|
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
|
||||||
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
|
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
|
||||||
|
|
||||||
# Vext_x, Vext_y, Vext_z: external velocity components
|
# Distribution of external velocity components
|
||||||
self._dist_Vext = dist.Uniform(-1000, 1000)
|
self._Vext = dist.Uniform(-1000, 1000)
|
||||||
# We want sigma_v to be 150 +- 100 km / s (lognormal)
|
# Distribution of density, velocity and location bias parameters
|
||||||
self._dist_sigma_v = dist.LogNormal(
|
self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
|
||||||
*lognorm_mean_std_to_loc_scale(150, 100))
|
self._beta = dist.Normal(1., 0.5)
|
||||||
# Density power-law bias
|
self._h = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))
|
||||||
self._dist_alpha = dist.LogNormal(
|
# Distribution of velocity uncertainty sigma_v
|
||||||
*lognorm_mean_std_to_loc_scale(1.0, 0.5))
|
self._sv = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))
|
||||||
# Velocity bias
|
|
||||||
self._dist_beta = dist.Normal(1., 0.5)
|
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self, sample_alpha=False, scale_distance=False):
|
||||||
"""
|
"""
|
||||||
The simple distance NumPyro PV validation model. Samples the following
|
The simple distance NumPyro PV validation model.
|
||||||
parameters:
|
|
||||||
- `Vext_x`, `Vext_y`, `Vext_z`: external velocity components
|
Parameters
|
||||||
- `alpha`: density bias parameter
|
----------
|
||||||
- `beta`: velocity bias parameter
|
sample_alpha : bool, optional
|
||||||
- `sigma_v`: velocity uncertainty
|
Whether to sample the density bias parameter `alpha`, otherwise
|
||||||
|
it is fixed to 1.
|
||||||
|
scale_distance : bool, optional
|
||||||
|
Whether to scale the distance by `h`, otherwise it is fixed to 1.
|
||||||
"""
|
"""
|
||||||
Vx = numpyro.sample("Vext_x", self._dist_Vext)
|
Vx = numpyro.sample("Vext_x", self._Vext)
|
||||||
Vy = numpyro.sample("Vext_y", self._dist_Vext)
|
Vy = numpyro.sample("Vext_y", self._Vext)
|
||||||
Vz = numpyro.sample("Vext_z", self._dist_Vext)
|
Vz = numpyro.sample("Vext_z", self._Vext)
|
||||||
alpha = numpyro.sample("alpha", self._dist_alpha)
|
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
|
||||||
beta = numpyro.sample("beta", self._dist_beta)
|
beta = numpyro.sample("beta", self._beta)
|
||||||
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
|
h = numpyro.sample("h", self._h) if scale_distance else 1.0
|
||||||
|
sigma_v = numpyro.sample("sigma_v", self._sv)
|
||||||
|
|
||||||
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
||||||
|
|
||||||
# Calculate p(r) and multiply it by the galaxy bias
|
# Calculate p(r) and multiply it by the galaxy bias
|
||||||
ptilde = self._vmap_ptilde_wo_bias(self._r_hMpc, self._e_rhMpc)
|
ptilde = self._vmap_ptilde_wo_bias(h * self._r_hMpc, h * self._e_rhMpc)
|
||||||
ptilde *= self._los_density**alpha
|
ptilde *= self._los_density**alpha
|
||||||
|
|
||||||
# Normalization of p(r)
|
# Normalization of p(r)
|
||||||
|
@ -608,50 +612,138 @@ class SD_PV_validation_model:
|
||||||
numpyro.factor("ll", ll)
|
numpyro.factor("ll", ll)
|
||||||
|
|
||||||
|
|
||||||
# def SN_PV_wcal_validation_model(los_overdensity=None, los_velocity=None,
|
class SN_PV_validation_model:
|
||||||
# RA=None, dec=None, z_CMB=None,
|
"""
|
||||||
# mB=None, x1=None, c=None,
|
Supernova peculiar velocity (PV) validation model that includes the
|
||||||
# e_mB=None, e_x1=None, e_c=None,
|
calibration of the SALT2 light curve parameters.
|
||||||
# mu_xrange=None, r_xrange=None,
|
|
||||||
# norm_r2_xrange=None, Omega_m=None, dr=None):
|
Parameters
|
||||||
# """
|
----------
|
||||||
# Pass
|
los_density : 2-dimensional array of shape (n_objects, n_steps)
|
||||||
# """
|
LOS density field.
|
||||||
# Vx = numpyro.sample("Vext_x", dist.Uniform(-1000, 1000))
|
los_velocity : 3-dimensional array of shape (n_objects, n_steps)
|
||||||
# Vy = numpyro.sample("Vext_y", dist.Uniform(-1000, 1000))
|
LOS radial velocity field.
|
||||||
# Vz = numpyro.sample("Vext_z", dist.Uniform(-1000, 1000))
|
RA, dec : 1-dimensional arrays of shape (n_objects)
|
||||||
# beta = numpyro.sample("beta", dist.Uniform(-10, 10))
|
Right ascension and declination in degrees.
|
||||||
#
|
z_obs : 1-dimensional array of shape (n_objects)
|
||||||
# # TODO: Later sample these as well.
|
Observed redshifts.
|
||||||
# e_mu_intrinsic = 0.064
|
mB, x1, c : 1-dimensional arrays of shape (n_objects)
|
||||||
# alpha_cal = 0.135
|
SALT2 light curve parameters.
|
||||||
# beta_cal = 2.9
|
e_mB, e_x1, e_c : 1-dimensional arrays of shape (n_objects)
|
||||||
# mag_cal = -18.555
|
Errors on the SALT2 light curve parameters.
|
||||||
# sigma_v = 112
|
r_xrange : 1-dimensional array
|
||||||
#
|
Radial distances where the field was interpolated for each object.
|
||||||
# # TODO: Check these for fiducial values.
|
Omega_m : float
|
||||||
# mu = mB - mag_cal + alpha_cal * x1 - beta_cal * c
|
Matter density parameter.
|
||||||
# squared_e_mu = e_mB**2 + alpha_cal**2 * e_x1**2 + beta_cal**2 * e_c**2
|
"""
|
||||||
#
|
|
||||||
# squared_e_mu += e_mu_intrinsic**2
|
def __init__(self, los_density, los_velocity, RA, dec, z_obs,
|
||||||
# ll = 0.
|
mB, x1, c, e_mB, e_x1, e_c, r_xrange, Omega_m):
|
||||||
# for i in range(len(los_overdensity)):
|
dt = jnp.float32
|
||||||
# # Project the external velocity for this galaxy.
|
# Convert everything to JAX arrays.
|
||||||
# Vext_rad = project_Vext(Vx, Vy, Vz, RA[i], dec[i])
|
self._los_density = jnp.asarray(los_density, dtype=dt)
|
||||||
#
|
self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
|
||||||
# dmu = mu_xrange - mu[i]
|
|
||||||
# ptilde = norm_r2_xrange * jnp.exp(-0.5 * dmu**2 / squared_e_mu[i])
|
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
|
||||||
# # TODO: Add some bias
|
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
|
||||||
# ptilde *= (1 + los_overdensity[i])
|
self._z_obs = jnp.asarray(z_obs, dtype=dt)
|
||||||
#
|
|
||||||
# zobs_pred = predict_zobs(r_xrange, beta, Vext_rad, los_velocity[i],
|
self._mB = jnp.asarray(mB, dtype=dt)
|
||||||
# Omega_m)
|
self._x1 = jnp.asarray(x1, dtype=dt)
|
||||||
#
|
self._c = jnp.asarray(c, dtype=dt)
|
||||||
# dczobs = SPEED_OF_LIGHT * (z_CMB[i] - zobs_pred)
|
self._e2_mB = jnp.asarray(e_mB**2, dtype=dt)
|
||||||
#
|
self._e2_x1 = jnp.asarray(e_x1**2, dtype=dt)
|
||||||
# ll_zobs = jnp.exp(-0.5 * (dczobs / sigma_v)**2) / sigma_v
|
self._e2_c = jnp.asarray(e_c**2, dtype=dt)
|
||||||
#
|
|
||||||
# ll += jnp.log(simps(ptilde * ll_zobs, dr))
|
# Get radius squared
|
||||||
# ll -= jnp.log(simps(ptilde, dr))
|
r2_xrange = r_xrange**2
|
||||||
#
|
r2_xrange /= r2_xrange.mean()
|
||||||
# numpyro.factor("ll", ll)
|
mu_xrange = dist2distmodulus(r_xrange, Omega_m)
|
||||||
|
|
||||||
|
# Get the stepsize, we need it to be constant for Simpson's rule.
|
||||||
|
dr = np.diff(r_xrange)
|
||||||
|
if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
|
||||||
|
raise ValueError("The radial step size must be constant.")
|
||||||
|
dr = dr[0]
|
||||||
|
|
||||||
|
# Get the various vmapped functions
|
||||||
|
self._vmap_ptilde_wo_bias = vmap(lambda mu, err: calculate_ptilde_wo_bias(mu_xrange, mu, err, r2_xrange)) # noqa
|
||||||
|
self._vmap_simps = vmap(lambda y: simps(y, dr))
|
||||||
|
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
|
||||||
|
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
|
||||||
|
|
||||||
|
# Distribution of external velocity components
|
||||||
|
self._dist_Vext = dist.Uniform(-1000, 1000)
|
||||||
|
# Distribution of velocity and density bias parameters
|
||||||
|
self._dist_alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
|
||||||
|
self._dist_beta = dist.Normal(1., 0.5)
|
||||||
|
# Distribution of velocity uncertainty
|
||||||
|
self._dist_sigma_v = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100)) # noqa
|
||||||
|
|
||||||
|
# Distribution of light curve calibration parameters
|
||||||
|
self._dist_mag_cal = dist.Normal(-18.25, 1.0)
|
||||||
|
self._dist_alpha_cal = dist.Normal(0.1, 0.5)
|
||||||
|
self._dist_beta_cal = dist.Normal(3.0, 1.0)
|
||||||
|
self._dist_e_mu = dist.LogNormal(*lognorm_mean_std_to_loc_scale(0.1, 0.05)) # noqa
|
||||||
|
|
||||||
|
def __call__(self, sample_alpha=False, fix_calibration=False):
|
||||||
|
"""
|
||||||
|
The supernova NumPyro PV validation model with SALT2 calibration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sample_alpha : bool, optional
|
||||||
|
Whether to sample the density bias parameter `alpha`, otherwise
|
||||||
|
it is fixed to 1.
|
||||||
|
"""
|
||||||
|
Vx = numpyro.sample("Vext_x", self._dist_Vext)
|
||||||
|
Vy = numpyro.sample("Vext_y", self._dist_Vext)
|
||||||
|
Vz = numpyro.sample("Vext_z", self._dist_Vext)
|
||||||
|
if sample_alpha:
|
||||||
|
alpha = numpyro.sample("alpha", self._dist_alpha)
|
||||||
|
else:
|
||||||
|
alpha = 1.0
|
||||||
|
beta = numpyro.sample("beta", self._dist_beta)
|
||||||
|
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
|
||||||
|
|
||||||
|
if fix_calibration == "Foundation":
|
||||||
|
# Foundation inverse best parameters
|
||||||
|
e_mu_intrinsic = 0.064
|
||||||
|
alpha_cal = 0.135
|
||||||
|
beta_cal = 2.9
|
||||||
|
sigma_v = 140
|
||||||
|
mag_cal = -18.555
|
||||||
|
elif fix_calibration == "LOSS":
|
||||||
|
# LOSS inverse best parameters
|
||||||
|
e_mu_intrinsic = 0.123
|
||||||
|
alpha_cal = 0.123
|
||||||
|
beta_cal = 3.52
|
||||||
|
mag_cal = -18.195
|
||||||
|
sigma_v = 140
|
||||||
|
else:
|
||||||
|
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._dist_e_mu)
|
||||||
|
mag_cal = numpyro.sample("mag_cal", self._dist_mag_cal)
|
||||||
|
alpha_cal = numpyro.sample("alpha_cal", self._dist_alpha_cal)
|
||||||
|
beta_cal = numpyro.sample("beta_cal", self._dist_beta_cal)
|
||||||
|
|
||||||
|
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
||||||
|
|
||||||
|
mu = self._mB - mag_cal + alpha_cal * self._x1 - beta_cal * self._c
|
||||||
|
squared_e_mu = (self._e2_mB
|
||||||
|
+ alpha_cal**2 * self._e2_x1
|
||||||
|
+ beta_cal**2 * self._e2_c)
|
||||||
|
squared_e_mu += e_mu_intrinsic**2
|
||||||
|
|
||||||
|
# Calculate p(r) and multiply it by the galaxy bias
|
||||||
|
ptilde = self._vmap_ptilde_wo_bias(mu, squared_e_mu**0.5)
|
||||||
|
ptilde *= self._los_density**alpha
|
||||||
|
|
||||||
|
# Normalization of p(r)
|
||||||
|
pnorm = self._vmap_simps(ptilde)
|
||||||
|
|
||||||
|
# Calculate p(z_obs) and multiply it by p(r)
|
||||||
|
zobs_pred = self._vmap_zobs(beta, Vext_rad, self._los_velocity)
|
||||||
|
ptilde *= self._vmap_ll_zobs(self._z_obs, zobs_pred, sigma_v)
|
||||||
|
|
||||||
|
ll = jnp.sum(jnp.log(self._vmap_simps(ptilde) / pnorm))
|
||||||
|
numpyro.factor("ll", ll)
|
||||||
|
|
Loading…
Reference in a new issue