Enable scalar or array values for measurement errors in TFR and SN tests

This commit is contained in:
Deaglan Bartlett 2025-02-09 12:00:35 +01:00
parent 4031daf5e2
commit 071beeae80
2 changed files with 57 additions and 42 deletions

View file

@ -333,16 +333,21 @@ def likelihood_stretch(stretch_true, stretch_obs, sigma_stretch):
Args:
- stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,))
- sigma_stretch (float): Uncertainty on the stretch measurements
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
Nt = stretch_obs.shape[0]
norm = jnp.where(
jnp.ndim(sigma_stretch) == 0,
Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2),
jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2))
)
loglike = - (
0.5 * jnp.sum((stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2)
+ Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2)
+ norm
)
return loglike
@ -355,16 +360,21 @@ def likelihood_c(c_true, c_obs, sigma_c):
Args:
- c_true (np.ndarray): True colours of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,))
- sigma_c (float): Uncertainty on the colours measurements
- sigma_c (float or np.ndarray): Uncertainty on the colours measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
Nt = c_obs.shape[0]
norm = jnp.where(
jnp.ndim(sigma_c) == 0,
Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2),
jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2))
)
loglike = - (
0.5 * jnp.sum((c_obs - c_true) ** 2 / sigma_c ** 2)
+ Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2)
+ norm
)
return loglike
@ -377,16 +387,21 @@ def likelihood_m(m_true, m_obs, sigma_m):
Args:
- m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
Nt = m_obs.shape[0]
norm = jnp.where(
jnp.ndim(sigma_m) == 0,
Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2),
jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
)
loglike = - (
0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2)
+ Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
+ norm
)
return loglike
@ -421,9 +436,9 @@ def likelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float): Uncertainty on the stretch measurements
- sigma_c (float): Uncertainty on the colour measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
- sigma_c (float or np.ndarray): Uncertainty on the colour measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
@ -475,9 +490,9 @@ def test_likelihood_scan(prior, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float): Uncertainty on the stretch measurements
- sigma_c (float): Uncertainty on the colour measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
- sigma_c (float or np.ndarray): Uncertainty on the colour measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
@ -546,9 +561,9 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin,
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float): Uncertainty on the stretch measurements
- sigma_c (float): Uncertainty on the colour measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
- sigma_c (float or np.ndarray): Uncertainty on the colour measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
@ -837,11 +852,4 @@ def main():
if __name__ == "__main__":
main()
"""
TO DO
- Deal with case where sigma_eta and sigma_m could be floats vs arrays
"""

View file

@ -509,7 +509,7 @@ def likelihood_m(m_true, m_obs, sigma_m, mthresh):
Args:
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- mthresh (float): Threshold absolute magnitude in selection
Returns:
@ -517,11 +517,20 @@ def likelihood_m(m_true, m_obs, sigma_m, mthresh):
"""
Nt = m_obs.shape[0]
norm = jnp.log(2) - jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) - 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
norm0 = (
jnp.log(2)
- jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m)))
- 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
)
norm1 = jnp.where(
jnp.ndim(sigma_m) == 0,
Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2),
jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
)
loglike = - (
0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2)
- jnp.sum(norm)
+ Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
- jnp.sum(norm0)
+ norm1
)
return loglike
@ -534,16 +543,21 @@ def likelihood_eta(eta_true, eta_obs, sigma_eta):
Args:
- eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_eta (float): Uncertainty on the linewidth measurements
- sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
Nt = eta_obs.shape[0]
norm = jnp.where(
jnp.ndim(sigma_eta) == 0,
Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2),
jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2))
)
loglike = - (
0.5 * jnp.sum((eta_obs - eta_true) ** 2 / sigma_eta ** 2)
+ Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2)
+ norm
)
return loglike
@ -575,8 +589,8 @@ def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk,
- cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_eta (float): Uncertainty on the linewidth measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
- mthresh (float): Threshold absolute magnitude in selection
@ -624,8 +638,8 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true,
- czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_eta (float): Uncertainty on the apparent linewidth measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_eta (float or np.ndarray): Uncertainty on the apparent linewidth measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
- mthresh (float): Threshold absolute magnitude in selection
@ -713,8 +727,8 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin,
- czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_m (float): Uncertainty on the apparent magnitude measurements
- sigma_eta (float): Uncertainty on the apparent linewidth measurements
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_eta (float or np.ndarray): Uncertainty on the apparent linewidth measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
- mthresh (float): Threshold absolute magnitude in selection
@ -976,10 +990,3 @@ def main():
if __name__ == "__main__":
main()
"""
TO DO
- Deal with case where sigma_eta and sigma_m could be floats vs arrays
"""