From 071beeae8071b33d1c86e189fc6e2d16e74dcb83 Mon Sep 17 00:00:00 2001 From: Deaglan Bartlett Date: Sun, 9 Feb 2025 12:00:35 +0100 Subject: [PATCH] Enable scalar or array values for measurement errors in TFR and SN tests --- tests/sn_inference.py | 54 ++++++++++++++++++++++++------------------ tests/tfr_inference.py | 45 ++++++++++++++++++++--------------- 2 files changed, 57 insertions(+), 42 deletions(-) diff --git a/tests/sn_inference.py b/tests/sn_inference.py index d4b6ee4..d86ec24 100644 --- a/tests/sn_inference.py +++ b/tests/sn_inference.py @@ -333,16 +333,21 @@ def likelihood_stretch(stretch_true, stretch_obs, sigma_stretch): Args: - stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,)) - stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,)) - - sigma_stretch (float): Uncertainty on the stretch measurements + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements Returns: - loglike (float): The log-likelihood of the data """ Nt = stretch_obs.shape[0] + norm = jnp.where( + jnp.ndim(sigma_stretch) == 0, + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2), + jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2)) + ) loglike = - ( 0.5 * jnp.sum((stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2) - + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2) + + norm ) return loglike @@ -355,16 +360,21 @@ def likelihood_c(c_true, c_obs, sigma_c): Args: - c_true (np.ndarray): True colours of the tracers (shape = (Nt,)) - c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,)) - - sigma_c (float): Uncertainty on the colours measurements + - sigma_c (float or np.ndarray): Uncertainty on the colours measurements Returns: - loglike (float): The log-likelihood of the data """ Nt = c_obs.shape[0] + norm = jnp.where( + jnp.ndim(sigma_c) == 0, + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2), + jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2)) + ) loglike = - ( 0.5 * jnp.sum((c_obs - c_true) ** 2 / sigma_c ** 2) - + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2) + + norm ) return loglike @@ -377,16 +387,21 @@ def likelihood_m(m_true, m_obs, sigma_m): Args: - m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,)) - m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements Returns: - loglike (float): The log-likelihood of the data """ Nt = m_obs.shape[0] + norm = jnp.where( + jnp.ndim(sigma_m) == 0, + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2), + jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)) + ) loglike = - ( 0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2) - + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) + + norm ) return loglike @@ -421,9 +436,9 @@ def likelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,)) - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_stretch (float): Uncertainty on the stretch measurements - - sigma_c (float): Uncertainty on the colour measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements + - sigma_c (float or np.ndarray): Uncertainty on the colour measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) @@ -475,9 +490,9 @@ def test_likelihood_scan(prior, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,)) - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_stretch (float): Uncertainty on the stretch measurements - - sigma_c (float): Uncertainty on the colour measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements + - sigma_c (float or np.ndarray): Uncertainty on the colour measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) @@ -546,9 +561,9 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,)) - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_stretch (float): Uncertainty on the stretch measurements - - sigma_c (float): Uncertainty on the colour measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements + - sigma_c (float or np.ndarray): Uncertainty on the colour measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) @@ -837,11 +852,4 @@ def main(): if __name__ == "__main__": main() - - -""" -TO DO - -- Deal with case where sigma_eta and sigma_m could be floats vs arrays - -""" \ No newline at end of file + \ No newline at end of file diff --git a/tests/tfr_inference.py b/tests/tfr_inference.py index f41863a..7fb37dc 100644 --- a/tests/tfr_inference.py +++ b/tests/tfr_inference.py @@ -509,7 +509,7 @@ def likelihood_m(m_true, m_obs, sigma_m, mthresh): Args: - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements - mthresh (float): Threshold absolute magnitude in selection Returns: @@ -517,11 +517,20 @@ def likelihood_m(m_true, m_obs, sigma_m, mthresh): """ Nt = m_obs.shape[0] - norm = jnp.log(2) - jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) - 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) + norm0 = ( + jnp.log(2) + - jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) + - 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) + ) + norm1 = jnp.where( + jnp.ndim(sigma_m) == 0, + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2), + jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)) + ) loglike = - ( 0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2) - - jnp.sum(norm) - + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) + - jnp.sum(norm0) + + norm1 ) return loglike @@ -534,16 +543,21 @@ def likelihood_eta(eta_true, eta_obs, sigma_eta): Args: - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) - - sigma_eta (float): Uncertainty on the linewidth measurements + - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements Returns: - loglike (float): The log-likelihood of the data """ Nt = eta_obs.shape[0] + norm = jnp.where( + jnp.ndim(sigma_eta) == 0, + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2), + jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2)) + ) loglike = - ( 0.5 * jnp.sum((eta_obs - eta_true) ** 2 / sigma_eta ** 2) - + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2) + + norm ) return loglike @@ -575,8 +589,8 @@ def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk, - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_eta (float): Uncertainty on the linewidth measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) - mthresh (float): Threshold absolute magnitude in selection @@ -624,8 +638,8 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_eta (float): Uncertainty on the apparent linewidth measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_eta (float or np.ndarray): Uncertainty on the apparent linewidth measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) - mthresh (float): Threshold absolute magnitude in selection @@ -713,8 +727,8 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) - - sigma_m (float): Uncertainty on the apparent magnitude measurements - - sigma_eta (float): Uncertainty on the apparent linewidth measurements + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_eta (float or np.ndarray): Uncertainty on the apparent linewidth measurements - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). The shape is (3, Nt, Nsig) - mthresh (float): Threshold absolute magnitude in selection @@ -976,10 +990,3 @@ def main(): if __name__ == "__main__": main() - -""" -TO DO - -- Deal with case where sigma_eta and sigma_m could be floats vs arrays - -""" \ No newline at end of file