Add SN calibration model

This commit is contained in:
rstiskalek 2024-03-10 11:19:13 +00:00
parent b503a6f003
commit 088b15429b

View file

@ -538,13 +538,15 @@ class SD_PV_validation_model:
def __init__(self, los_density, los_velocity, RA, dec, z_obs, def __init__(self, los_density, los_velocity, RA, dec, z_obs,
r_hMpc, e_r_hMpc, r_xrange, Omega_m): r_hMpc, e_r_hMpc, r_xrange, Omega_m):
# Convert everything to JAX arrays.
dt = jnp.float32 dt = jnp.float32
# Convert everything to JAX arrays.
self._los_density = jnp.asarray(los_density, dtype=dt) self._los_density = jnp.asarray(los_density, dtype=dt)
self._los_velocity = jnp.asarray(los_velocity, dtype=dt) self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt) self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt) self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
self._z_obs = jnp.asarray(z_obs, dtype=dt) self._z_obs = jnp.asarray(z_obs, dtype=dt)
self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt) self._r_hMpc = jnp.asarray(r_hMpc, dtype=dt)
self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt) self._e_rhMpc = jnp.asarray(e_r_hMpc, dtype=dt)
@ -564,37 +566,39 @@ class SD_PV_validation_model:
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
# Vext_x, Vext_y, Vext_z: external velocity components # Distribution of external velocity components
self._dist_Vext = dist.Uniform(-1000, 1000) self._Vext = dist.Uniform(-1000, 1000)
# We want sigma_v to be 150 +- 100 km / s (lognormal) # Distribution of density, velocity and location bias parameters
self._dist_sigma_v = dist.LogNormal( self._alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
*lognorm_mean_std_to_loc_scale(150, 100)) self._beta = dist.Normal(1., 0.5)
# Density power-law bias self._h = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5))
self._dist_alpha = dist.LogNormal( # Distribution of velocity uncertainty sigma_v
*lognorm_mean_std_to_loc_scale(1.0, 0.5)) self._sv = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100))
# Velocity bias
self._dist_beta = dist.Normal(1., 0.5)
def __call__(self): def __call__(self, sample_alpha=False, scale_distance=False):
""" """
The simple distance NumPyro PV validation model. Samples the following The simple distance NumPyro PV validation model.
parameters:
- `Vext_x`, `Vext_y`, `Vext_z`: external velocity components Parameters
- `alpha`: density bias parameter ----------
- `beta`: velocity bias parameter sample_alpha : bool, optional
- `sigma_v`: velocity uncertainty Whether to sample the density bias parameter `alpha`, otherwise
it is fixed to 1.
scale_distance : bool, optional
Whether to scale the distance by `h`, otherwise it is fixed to 1.
""" """
Vx = numpyro.sample("Vext_x", self._dist_Vext) Vx = numpyro.sample("Vext_x", self._Vext)
Vy = numpyro.sample("Vext_y", self._dist_Vext) Vy = numpyro.sample("Vext_y", self._Vext)
Vz = numpyro.sample("Vext_z", self._dist_Vext) Vz = numpyro.sample("Vext_z", self._Vext)
alpha = numpyro.sample("alpha", self._dist_alpha) alpha = numpyro.sample("alpha", self._alpha) if sample_alpha else 1.0
beta = numpyro.sample("beta", self._dist_beta) beta = numpyro.sample("beta", self._beta)
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v) h = numpyro.sample("h", self._h) if scale_distance else 1.0
sigma_v = numpyro.sample("sigma_v", self._sv)
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec) Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
# Calculate p(r) and multiply it by the galaxy bias # Calculate p(r) and multiply it by the galaxy bias
ptilde = self._vmap_ptilde_wo_bias(self._r_hMpc, self._e_rhMpc) ptilde = self._vmap_ptilde_wo_bias(h * self._r_hMpc, h * self._e_rhMpc)
ptilde *= self._los_density**alpha ptilde *= self._los_density**alpha
# Normalization of p(r) # Normalization of p(r)
@ -608,50 +612,138 @@ class SD_PV_validation_model:
numpyro.factor("ll", ll) numpyro.factor("ll", ll)
# def SN_PV_wcal_validation_model(los_overdensity=None, los_velocity=None, class SN_PV_validation_model:
# RA=None, dec=None, z_CMB=None, """
# mB=None, x1=None, c=None, Supernova peculiar velocity (PV) validation model that includes the
# e_mB=None, e_x1=None, e_c=None, calibration of the SALT2 light curve parameters.
# mu_xrange=None, r_xrange=None,
# norm_r2_xrange=None, Omega_m=None, dr=None): Parameters
# """ ----------
# Pass los_density : 2-dimensional array of shape (n_objects, n_steps)
# """ LOS density field.
# Vx = numpyro.sample("Vext_x", dist.Uniform(-1000, 1000)) los_velocity : 3-dimensional array of shape (n_objects, n_steps)
# Vy = numpyro.sample("Vext_y", dist.Uniform(-1000, 1000)) LOS radial velocity field.
# Vz = numpyro.sample("Vext_z", dist.Uniform(-1000, 1000)) RA, dec : 1-dimensional arrays of shape (n_objects)
# beta = numpyro.sample("beta", dist.Uniform(-10, 10)) Right ascension and declination in degrees.
# z_obs : 1-dimensional array of shape (n_objects)
# # TODO: Later sample these as well. Observed redshifts.
# e_mu_intrinsic = 0.064 mB, x1, c : 1-dimensional arrays of shape (n_objects)
# alpha_cal = 0.135 SALT2 light curve parameters.
# beta_cal = 2.9 e_mB, e_x1, e_c : 1-dimensional arrays of shape (n_objects)
# mag_cal = -18.555 Errors on the SALT2 light curve parameters.
# sigma_v = 112 r_xrange : 1-dimensional array
# Radial distances where the field was interpolated for each object.
# # TODO: Check these for fiducial values. Omega_m : float
# mu = mB - mag_cal + alpha_cal * x1 - beta_cal * c Matter density parameter.
# squared_e_mu = e_mB**2 + alpha_cal**2 * e_x1**2 + beta_cal**2 * e_c**2 """
#
# squared_e_mu += e_mu_intrinsic**2 def __init__(self, los_density, los_velocity, RA, dec, z_obs,
# ll = 0. mB, x1, c, e_mB, e_x1, e_c, r_xrange, Omega_m):
# for i in range(len(los_overdensity)): dt = jnp.float32
# # Project the external velocity for this galaxy. # Convert everything to JAX arrays.
# Vext_rad = project_Vext(Vx, Vy, Vz, RA[i], dec[i]) self._los_density = jnp.asarray(los_density, dtype=dt)
# self._los_velocity = jnp.asarray(los_velocity, dtype=dt)
# dmu = mu_xrange - mu[i]
# ptilde = norm_r2_xrange * jnp.exp(-0.5 * dmu**2 / squared_e_mu[i]) self._RA = jnp.asarray(np.deg2rad(RA), dtype=dt)
# # TODO: Add some bias self._dec = jnp.asarray(np.deg2rad(dec), dtype=dt)
# ptilde *= (1 + los_overdensity[i]) self._z_obs = jnp.asarray(z_obs, dtype=dt)
#
# zobs_pred = predict_zobs(r_xrange, beta, Vext_rad, los_velocity[i], self._mB = jnp.asarray(mB, dtype=dt)
# Omega_m) self._x1 = jnp.asarray(x1, dtype=dt)
# self._c = jnp.asarray(c, dtype=dt)
# dczobs = SPEED_OF_LIGHT * (z_CMB[i] - zobs_pred) self._e2_mB = jnp.asarray(e_mB**2, dtype=dt)
# self._e2_x1 = jnp.asarray(e_x1**2, dtype=dt)
# ll_zobs = jnp.exp(-0.5 * (dczobs / sigma_v)**2) / sigma_v self._e2_c = jnp.asarray(e_c**2, dtype=dt)
#
# ll += jnp.log(simps(ptilde * ll_zobs, dr)) # Get radius squared
# ll -= jnp.log(simps(ptilde, dr)) r2_xrange = r_xrange**2
# r2_xrange /= r2_xrange.mean()
# numpyro.factor("ll", ll) mu_xrange = dist2distmodulus(r_xrange, Omega_m)
# Get the stepsize, we need it to be constant for Simpson's rule.
dr = np.diff(r_xrange)
if not np.all(np.isclose(dr, dr[0], atol=1e-5)):
raise ValueError("The radial step size must be constant.")
dr = dr[0]
# Get the various vmapped functions
self._vmap_ptilde_wo_bias = vmap(lambda mu, err: calculate_ptilde_wo_bias(mu_xrange, mu, err, r2_xrange)) # noqa
self._vmap_simps = vmap(lambda y: simps(y, dr))
self._vmap_zobs = vmap(lambda beta, Vr, vpec_rad: predict_zobs(r_xrange, beta, Vr, vpec_rad, Omega_m), in_axes=(None, 0, 0)) # noqa
self._vmap_ll_zobs = vmap(lambda zobs, zobs_pred, sigma_v: calculate_ll_zobs(zobs, zobs_pred, sigma_v), in_axes=(0, 0, None)) # noqa
# Distribution of external velocity components
self._dist_Vext = dist.Uniform(-1000, 1000)
# Distribution of velocity and density bias parameters
self._dist_alpha = dist.LogNormal(*lognorm_mean_std_to_loc_scale(1.0, 0.5)) # noqa
self._dist_beta = dist.Normal(1., 0.5)
# Distribution of velocity uncertainty
self._dist_sigma_v = dist.LogNormal(*lognorm_mean_std_to_loc_scale(150, 100)) # noqa
# Distribution of light curve calibration parameters
self._dist_mag_cal = dist.Normal(-18.25, 1.0)
self._dist_alpha_cal = dist.Normal(0.1, 0.5)
self._dist_beta_cal = dist.Normal(3.0, 1.0)
self._dist_e_mu = dist.LogNormal(*lognorm_mean_std_to_loc_scale(0.1, 0.05)) # noqa
def __call__(self, sample_alpha=False, fix_calibration=False):
"""
The supernova NumPyro PV validation model with SALT2 calibration.
Parameters
----------
sample_alpha : bool, optional
Whether to sample the density bias parameter `alpha`, otherwise
it is fixed to 1.
"""
Vx = numpyro.sample("Vext_x", self._dist_Vext)
Vy = numpyro.sample("Vext_y", self._dist_Vext)
Vz = numpyro.sample("Vext_z", self._dist_Vext)
if sample_alpha:
alpha = numpyro.sample("alpha", self._dist_alpha)
else:
alpha = 1.0
beta = numpyro.sample("beta", self._dist_beta)
sigma_v = numpyro.sample("sigma_v", self._dist_sigma_v)
if fix_calibration == "Foundation":
# Foundation inverse best parameters
e_mu_intrinsic = 0.064
alpha_cal = 0.135
beta_cal = 2.9
sigma_v = 140
mag_cal = -18.555
elif fix_calibration == "LOSS":
# LOSS inverse best parameters
e_mu_intrinsic = 0.123
alpha_cal = 0.123
beta_cal = 3.52
mag_cal = -18.195
sigma_v = 140
else:
e_mu_intrinsic = numpyro.sample("e_mu_intrinsic", self._dist_e_mu)
mag_cal = numpyro.sample("mag_cal", self._dist_mag_cal)
alpha_cal = numpyro.sample("alpha_cal", self._dist_alpha_cal)
beta_cal = numpyro.sample("beta_cal", self._dist_beta_cal)
Vext_rad = project_Vext(Vx, Vy, Vz, self._RA, self._dec)
mu = self._mB - mag_cal + alpha_cal * self._x1 - beta_cal * self._c
squared_e_mu = (self._e2_mB
+ alpha_cal**2 * self._e2_x1
+ beta_cal**2 * self._e2_c)
squared_e_mu += e_mu_intrinsic**2
# Calculate p(r) and multiply it by the galaxy bias
ptilde = self._vmap_ptilde_wo_bias(mu, squared_e_mu**0.5)
ptilde *= self._los_density**alpha
# Normalization of p(r)
pnorm = self._vmap_simps(ptilde)
# Calculate p(z_obs) and multiply it by p(r)
zobs_pred = self._vmap_zobs(beta, Vext_rad, self._los_velocity)
ptilde *= self._vmap_ll_zobs(self._z_obs, zobs_pred, sigma_v)
ll = jnp.sum(jnp.log(self._vmap_simps(ptilde) / pnorm))
numpyro.factor("ll", ll)