mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 17:08:03 +00:00
Add SN calibration model
This commit is contained in:
parent
b503a6f003
commit
088b15429b
1 changed files with 164 additions and 72 deletions
|
@ -538,13 +538,15 @@ class SD_PV_validation_model:
|
|||
|
||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs,
|
||||
r_hMpc, e_r_hMpc, r_xrange, Omega_m):
|
||||
# Convert everything to JAX arrays.
|
||||
dt = jnp.float32
|
||||
# Convert everything to JAX arrays.
|
||||
self._los_density = jnp.asarray(los_density, dtype=dt)
|
||||
self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
|
||||
|
||||
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
|
||||
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
|
||||
self._z_obs = jnp.asarray(z_obs, dtype=dt)
|
||||
|
||||
self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt)
|
||||
self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt)
|
||||
|
||||
|
@ -564,37 +566,39 @@ class SD_PV_validation_model:
|
|||
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
|
||||
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
|
||||
|
||||
# Vext_x, Vext_y, Vext_z: external velocity components
|
||||
self._dist_Vext = dist.Uniform(-1000, 1000)
|
||||
# We want sigma_v to be 150 +- 100 km / s (lognormal)
|
||||
self._dist_sigma_v = dist.LogNormal(
|
||||
*lognorm_mean_std_to_loc_scale(150, 100))
|
||||
# Density power-law bias
|
||||
self._dist_alpha = dist.LogNormal(
|
||||
*lognorm_mean_std_to_loc_scale(1.0, 0.5))
|
||||
# Velocity bias
|
||||
self._dist_beta = dist.Normal(1., 0.5)
|
||||
# Distribution of external velocity components
|
||||
self._Vext = dist.Uniform(-1000, 1000)
|
||||
# Distribution of density, velocity and location bias parameters
|
||||
self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
|
||||
self._beta = dist.Normal(1., 0.5)
|
||||
self._h = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))
|
||||
# Distribution of velocity uncertainty sigma_v
|
||||
self._sv = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self, sample_alpha=False, scale_distance=False):
|
||||
"""
|
||||
The simple distance NumPyro PV validation model. Samples the following
|
||||
parameters:
|
||||
- `Vext_x`, `Vext_y`, `Vext_z`: external velocity components
|
||||
- `alpha`: density bias parameter
|
||||
- `beta`: velocity bias parameter
|
||||
- `sigma_v`: velocity uncertainty
|
||||
The simple distance NumPyro PV validation model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sample_alpha : bool, optional
|
||||
Whether to sample the density bias parameter `alpha`, otherwise
|
||||
it is fixed to 1.
|
||||
scale_distance : bool, optional
|
||||
Whether to scale the distance by `h`, otherwise it is fixed to 1.
|
||||
"""
|
||||
Vx = numpyro.sample("Vext_x", self._dist_Vext)
|
||||
Vy = numpyro.sample("Vext_y", self._dist_Vext)
|
||||
Vz = numpyro.sample("Vext_z", self._dist_Vext)
|
||||
alpha = numpyro.sample("alpha", self._dist_alpha)
|
||||
beta = numpyro.sample("beta", self._dist_beta)
|
||||
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
|
||||
Vx = numpyro.sample("Vext_x", self._Vext)
|
||||
Vy = numpyro.sample("Vext_y", self._Vext)
|
||||
Vz = numpyro.sample("Vext_z", self._Vext)
|
||||
alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
|
||||
beta = numpyro.sample("beta", self._beta)
|
||||
h = numpyro.sample("h", self._h) if scale_distance else 1.0
|
||||
sigma_v = numpyro.sample("sigma_v", self._sv)
|
||||
|
||||
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
||||
|
||||
# Calculate p(r) and multiply it by the galaxy bias
|
||||
ptilde = self._vmap_ptilde_wo_bias(self._r_hMpc, self._e_rhMpc)
|
||||
ptilde = self._vmap_ptilde_wo_bias(h * self._r_hMpc, h * self._e_rhMpc)
|
||||
ptilde *= self._los_density**alpha
|
||||
|
||||
# Normalization of p(r)
|
||||
|
@ -608,50 +612,138 @@ class SD_PV_validation_model:
|
|||
numpyro.factor("ll", ll)
|
||||
|
||||
|
||||
# def SN_PV_wcal_validation_model(los_overdensity=None, los_velocity=None,
|
||||
# RA=None, dec=None, z_CMB=None,
|
||||
# mB=None, x1=None, c=None,
|
||||
# e_mB=None, e_x1=None, e_c=None,
|
||||
# mu_xrange=None, r_xrange=None,
|
||||
# norm_r2_xrange=None, Omega_m=None, dr=None):
|
||||
# """
|
||||
# Pass
|
||||
# """
|
||||
# Vx = numpyro.sample("Vext_x", dist.Uniform(-1000, 1000))
|
||||
# Vy = numpyro.sample("Vext_y", dist.Uniform(-1000, 1000))
|
||||
# Vz = numpyro.sample("Vext_z", dist.Uniform(-1000, 1000))
|
||||
# beta = numpyro.sample("beta", dist.Uniform(-10, 10))
|
||||
#
|
||||
# # TODO: Later sample these as well.
|
||||
# e_mu_intrinsic = 0.064
|
||||
# alpha_cal = 0.135
|
||||
# beta_cal = 2.9
|
||||
# mag_cal = -18.555
|
||||
# sigma_v = 112
|
||||
#
|
||||
# # TODO: Check these for fiducial values.
|
||||
# mu = mB - mag_cal + alpha_cal * x1 - beta_cal * c
|
||||
# squared_e_mu = e_mB**2 + alpha_cal**2 * e_x1**2 + beta_cal**2 * e_c**2
|
||||
#
|
||||
# squared_e_mu += e_mu_intrinsic**2
|
||||
# ll = 0.
|
||||
# for i in range(len(los_overdensity)):
|
||||
# # Project the external velocity for this galaxy.
|
||||
# Vext_rad = project_Vext(Vx, Vy, Vz, RA[i], dec[i])
|
||||
#
|
||||
# dmu = mu_xrange - mu[i]
|
||||
# ptilde = norm_r2_xrange * jnp.exp(-0.5 * dmu**2 / squared_e_mu[i])
|
||||
# # TODO: Add some bias
|
||||
# ptilde *= (1 + los_overdensity[i])
|
||||
#
|
||||
# zobs_pred = predict_zobs(r_xrange, beta, Vext_rad, los_velocity[i],
|
||||
# Omega_m)
|
||||
#
|
||||
# dczobs = SPEED_OF_LIGHT * (z_CMB[i] - zobs_pred)
|
||||
#
|
||||
# ll_zobs = jnp.exp(-0.5 * (dczobs / sigma_v)**2) / sigma_v
|
||||
#
|
||||
# ll += jnp.log(simps(ptilde * ll_zobs, dr))
|
||||
# ll -= jnp.log(simps(ptilde, dr))
|
||||
#
|
||||
# numpyro.factor("ll", ll)
|
||||
class SN_PV_validation_model:
|
||||
"""
|
||||
Supernova peculiar velocity (PV) validation model that includes the
|
||||
calibration of the SALT2 light curve parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
los_density : 2-dimensional array of shape (n_objects, n_steps)
|
||||
LOS density field.
|
||||
los_velocity : 3-dimensional array of shape (n_objects, n_steps)
|
||||
LOS radial velocity field.
|
||||
RA, dec : 1-dimensional arrays of shape (n_objects)
|
||||
Right ascension and declination in degrees.
|
||||
z_obs : 1-dimensional array of shape (n_objects)
|
||||
Observed redshifts.
|
||||
mB, x1, c : 1-dimensional arrays of shape (n_objects)
|
||||
SALT2 light curve parameters.
|
||||
e_mB, e_x1, e_c : 1-dimensional arrays of shape (n_objects)
|
||||
Errors on the SALT2 light curve parameters.
|
||||
r_xrange : 1-dimensional array
|
||||
Radial distances where the field was interpolated for each object.
|
||||
Omega_m : float
|
||||
Matter density parameter.
|
||||
"""
|
||||
|
||||
def __init__(self, los_density, los_velocity, RA, dec, z_obs,
|
||||
mB, x1, c, e_mB, e_x1, e_c, r_xrange, Omega_m):
|
||||
dt = jnp.float32
|
||||
# Convert everything to JAX arrays.
|
||||
self._los_density = jnp.asarray(los_density, dtype=dt)
|
||||
self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
|
||||
|
||||
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
|
||||
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
|
||||
self._z_obs = jnp.asarray(z_obs, dtype=dt)
|
||||
|
||||
self._mB = jnp.asarray(mB, dtype=dt)
|
||||
self._x1 = jnp.asarray(x1, dtype=dt)
|
||||
self._c = jnp.asarray(c, dtype=dt)
|
||||
self._e2_mB = jnp.asarray(e_mB**2, dtype=dt)
|
||||
self._e2_x1 = jnp.asarray(e_x1**2, dtype=dt)
|
||||
self._e2_c = jnp.asarray(e_c**2, dtype=dt)
|
||||
|
||||
# Get radius squared
|
||||
r2_xrange = r_xrange**2
|
||||
r2_xrange /= r2_xrange.mean()
|
||||
mu_xrange = dist2distmodulus(r_xrange, Omega_m)
|
||||
|
||||
# Get the stepsize, we need it to be constant for Simpson's rule.
|
||||
dr = np.diff(r_xrange)
|
||||
if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
|
||||
raise ValueError("The radial step size must be constant.")
|
||||
dr = dr[0]
|
||||
|
||||
# Get the various vmapped functions
|
||||
self._vmap_ptilde_wo_bias = vmap(lambda mu, err: calculate_ptilde_wo_bias(mu_xrange, mu, err, r2_xrange)) # noqa
|
||||
self._vmap_simps = vmap(lambda y: simps(y, dr))
|
||||
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
|
||||
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
|
||||
|
||||
# Distribution of external velocity components
|
||||
self._dist_Vext = dist.Uniform(-1000, 1000)
|
||||
# Distribution of velocity and density bias parameters
|
||||
self._dist_alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
|
||||
self._dist_beta = dist.Normal(1., 0.5)
|
||||
# Distribution of velocity uncertainty
|
||||
self._dist_sigma_v = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100)) # noqa
|
||||
|
||||
# Distribution of light curve calibration parameters
|
||||
self._dist_mag_cal = dist.Normal(-18.25, 1.0)
|
||||
self._dist_alpha_cal = dist.Normal(0.1, 0.5)
|
||||
self._dist_beta_cal = dist.Normal(3.0, 1.0)
|
||||
self._dist_e_mu = dist.LogNormal(*lognorm_mean_std_to_loc_scale(0.1, 0.05)) # noqa
|
||||
|
||||
def __call__(self, sample_alpha=False, fix_calibration=False):
|
||||
"""
|
||||
The supernova NumPyro PV validation model with SALT2 calibration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sample_alpha : bool, optional
|
||||
Whether to sample the density bias parameter `alpha`, otherwise
|
||||
it is fixed to 1.
|
||||
"""
|
||||
Vx = numpyro.sample("Vext_x", self._dist_Vext)
|
||||
Vy = numpyro.sample("Vext_y", self._dist_Vext)
|
||||
Vz = numpyro.sample("Vext_z", self._dist_Vext)
|
||||
if sample_alpha:
|
||||
alpha = numpyro.sample("alpha", self._dist_alpha)
|
||||
else:
|
||||
alpha = 1.0
|
||||
beta = numpyro.sample("beta", self._dist_beta)
|
||||
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
|
||||
|
||||
if fix_calibration == "Foundation":
|
||||
# Foundation inverse best parameters
|
||||
e_mu_intrinsic = 0.064
|
||||
alpha_cal = 0.135
|
||||
beta_cal = 2.9
|
||||
sigma_v = 140
|
||||
mag_cal = -18.555
|
||||
elif fix_calibration == "LOSS":
|
||||
# LOSS inverse best parameters
|
||||
e_mu_intrinsic = 0.123
|
||||
alpha_cal = 0.123
|
||||
beta_cal = 3.52
|
||||
mag_cal = -18.195
|
||||
sigma_v = 140
|
||||
else:
|
||||
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._dist_e_mu)
|
||||
mag_cal = numpyro.sample("mag_cal", self._dist_mag_cal)
|
||||
alpha_cal = numpyro.sample("alpha_cal", self._dist_alpha_cal)
|
||||
beta_cal = numpyro.sample("beta_cal", self._dist_beta_cal)
|
||||
|
||||
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
|
||||
|
||||
mu = self._mB - mag_cal + alpha_cal * self._x1 - beta_cal * self._c
|
||||
squared_e_mu = (self._e2_mB
|
||||
+ alpha_cal**2 * self._e2_x1
|
||||
+ beta_cal**2 * self._e2_c)
|
||||
squared_e_mu += e_mu_intrinsic**2
|
||||
|
||||
# Calculate p(r) and multiply it by the galaxy bias
|
||||
ptilde = self._vmap_ptilde_wo_bias(mu, squared_e_mu**0.5)
|
||||
ptilde *= self._los_density**alpha
|
||||
|
||||
# Normalization of p(r)
|
||||
pnorm = self._vmap_simps(ptilde)
|
||||
|
||||
# Calculate p(z_obs) and multiply it by p(r)
|
||||
zobs_pred = self._vmap_zobs(beta, Vext_rad, self._los_velocity)
|
||||
ptilde *= self._vmap_ll_zobs(self._z_obs, zobs_pred, sigma_v)
|
||||
|
||||
ll = jnp.sum(jnp.log(self._vmap_simps(ptilde) / pnorm))
|
||||
numpyro.factor("ll", ll)
|
||||
|
|
Loading…
Reference in a new issue