def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, interp_order, bias_epsilon,
             czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos, seed=6):
    """
    Run a NUTS MCMC over the supernova (Tripp) model parameters

    Args:
        - num_warmup (int): Number of warmup steps to take in the MCMC
        - num_samples (int): Number of samples to take in the MCMC
        - prior (dict): Maps a parameter name to a pair (low, high) used as the bounds of a
            uniform prior. Keys read here: 'alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_v',
            'hyper_mean_m', 'hyper_mean_stretch', 'hyper_mean_c'
        - initial (dict): Maps a sample-site name to the value at which the chain is initialised.
            This dictionary is not modified; the latent variables, correlation factor and
            bulk flow initial values are added to a copy
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch values of the tracers; its leading
            dimension defines Nt
        - c_obs (np.ndarray): Observed colour values of the tracers (stacked with m_obs and
            stretch_obs, so it must have the same shape)
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_stretch: Uncertainty on the stretch measurements (forwarded to `likelihood`;
            presumably a float like sigma_m - TODO confirm)
        - sigma_c: Uncertainty on the colour measurements (forwarded to `likelihood`;
            presumably a float like sigma_m - TODO confirm)
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)
        - seed (int): Seed of the JAX random key used to run the chain (default 6)

    Returns:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run

    """

    Nt = stretch_obs.shape[0]
    omega_m = cpar.omega_m
    h = cpar.h
    sigma_bulk = utils.get_sigma_bulk(L, cpar)

    class SNLikelihood(dist.Distribution):
        """Wrap `likelihood` as a numpyro distribution so it can be attached to an observed site."""

        support = dist.constraints.real

        def __init__(self, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                     m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z):
            params = (alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                      m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z)
            (self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v,
             self.m_true, self.stretch_true, self.c_true,
             self.vbulk_x, self.vbulk_y, self.vbulk_z) = dist.util.promote_shapes(*params)
            batch_shape = lax.broadcast_shapes(*(jnp.shape(p) for p in params))
            super().__init__(batch_shape=batch_shape)

        def sample(self, key, sample_shape=()):
            raise NotImplementedError

        def log_prob(self, value):
            # `value` is not used: the observed data (and the fields, cosmology and
            # integration points) are taken from the enclosing run_mcmc scope
            vbulk = jnp.array([self.vbulk_x, self.vbulk_y, self.vbulk_z])
            loglike = likelihood(self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v,
                                 self.m_true, self.stretch_true, self.c_true, vbulk,
                                 dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                                 czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)
            return loglike

    def sn_model():

        alpha = numpyro.sample("alpha", dist.Uniform(*prior['alpha']))
        a_tripp = numpyro.sample("a_tripp", dist.Uniform(*prior['a_tripp']))
        b_tripp = numpyro.sample("b_tripp", dist.Uniform(*prior['b_tripp']))
        M_SN = numpyro.sample("M_SN", dist.Uniform(*prior['M_SN']))
        sigma_SN = numpyro.sample("sigma_SN", dist.HalfCauchy(1.0))
        sigma_v = numpyro.sample("sigma_v", dist.Uniform(*prior['sigma_v']))

        # Hyper-parameters of the population of (m, stretch, c).
        # NOTE(review): HalfCauchy(1) is a heavy-tailed scale prior, but it is not a
        # scale-invariant 1/sigma prior - confirm this is the intended choice
        hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m']))
        hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0))
        hyper_mean_stretch = numpyro.sample("hyper_mean_stretch", dist.Uniform(*prior['hyper_mean_stretch']))
        hyper_sigma_stretch = numpyro.sample("hyper_sigma_stretch", dist.HalfCauchy(1.0))
        hyper_mean_c = numpyro.sample("hyper_mean_c", dist.Uniform(*prior['hyper_mean_c']))
        hyper_sigma_c = numpyro.sample("hyper_sigma_c", dist.HalfCauchy(1.0))

        # Cholesky factor of the correlation matrix, with an LKJ prior
        L_corr = numpyro.sample("L_corr", dist.LKJCholesky(3, concentration=1.0))

        # Covariance is Sigma = D Corr D = (D L)(D L)^T with D = diag(hyper_sigma), so D L is
        # the Cholesky factor of Sigma. Passing it directly avoids forming Sigma and having
        # the distribution factorise it again.
        hyper_mean = jnp.array([hyper_mean_m, hyper_mean_stretch, hyper_mean_c])
        hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_stretch, hyper_sigma_c])
        hyper_scale_tril = hyper_sigma[:, None] * L_corr

        # Sample the true (m, stretch, c) of every tracer; x has shape (Nt, 3)
        x = numpyro.sample("true_vars", dist.MultivariateNormal(hyper_mean, scale_tril=hyper_scale_tril),
                           sample_shape=(Nt,))
        m_true = numpyro.deterministic("m_true", x[:, 0])
        stretch_true = numpyro.deterministic("stretch_true", x[:, 1])
        c_true = numpyro.deterministic("c_true", x[:, 2])

        # Sample bulk velocity (each Cartesian component has standard deviation sigma_bulk / sqrt(3))
        vbulk_x = numpyro.sample("vbulk_x", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_y = numpyro.sample("vbulk_y", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_z = numpyro.sample("vbulk_z", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))

        # Evaluate the likelihood
        numpyro.sample("obs",
                       SNLikelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                                    m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z),
                       obs=jnp.array([m_obs, stretch_obs, c_obs]))

    rng_key = random.PRNGKey(seed)
    rng_key, rng_key_ = random.split(rng_key)

    # Work on a copy so that the caller's dictionary is left untouched
    values = dict(initial)
    values['true_vars'] = jnp.array([m_obs, stretch_obs, c_obs]).T
    values['L_corr'] = jnp.identity(3)
    values['vbulk_x'] = 0.
    values['vbulk_y'] = 0.
    values['vbulk_z'] = 0.

    myprint('Preparing MCMC kernel')
    kernel = numpyro.infer.NUTS(sn_model,
                                init_strategy=numpyro.infer.initialization.init_to_value(values=values))
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    myprint('Running MCMC')
    mcmc.run(rng_key_)
    mcmc.print_summary()

    return mcmc


def process_mcmc_run(mcmc, param_labels, truths, true_vars):
    """
    Make summary plots from the MCMC and save these to file

    Writes 'sn_trace.png', 'sn_corner.png' and, for each of 'stretch', 'c' and 'm' whose
    '<name>_true' site is present in the samples, 'sn_true_predicted_<name>.png'.
    Every figure is closed once it has been saved.

    Args:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run
        - param_labels (list[str]): Names of the parameters to plot. The labels 'hyper_corr_mx',
            'hyper_corr_mc' and 'hyper_corr_xc' are computed from the sampled 'L_corr' factor;
            any other label must be a key of the samples
        - truths (list[float]): True values of the parameters to plot. If unknown, then entry is None
        - true_vars (dict): True values of the observables to compare against inferred ones
            (keys 'm', 'stretch', 'c')

    Raises:
        - NotImplementedError: If a label starts with 'hyper_corr' but is not one of the three above
    """

    samples = mcmc.get_samples()

    # (row, column) of the correlation matrix corresponding to each derived label
    corr_index = {'hyper_corr_mx': (0, 1), 'hyper_corr_mc': (0, 2), 'hyper_corr_xc': (1, 2)}
    corr_matrix = None

    # Convert samples into a single array with one column per requested parameter
    columns = []
    for p in param_labels:
        if p.startswith('hyper_corr'):
            if p not in corr_index:
                raise NotImplementedError(f"Unknown correlation parameter '{p}'")
            if corr_matrix is None:
                # Corr = L L^T for every sample; computed once and reused
                L_corr = samples['L_corr']
                corr_matrix = jnp.matmul(L_corr, jnp.transpose(L_corr, (0, 2, 1)))
            row, col = corr_index[p]
            columns.append(np.asarray(corr_matrix[:, row, col]))
        else:
            columns.append(np.asarray(samples[p]))
    samps = np.column_stack(columns)
    n_params = samps.shape[1]

    # Trace plot of non-redshift quantities
    fig1, axs1 = plt.subplots(n_params, 1, figsize=(6, 3*n_params), sharex=True)
    axs1 = np.atleast_1d(axs1)
    for i in range(n_params):
        axs1[i].plot(samps[:, i])
        axs1[i].set_ylabel(param_labels[i])
        if truths[i] is not None:
            axs1[i].axhline(truths[i], color='k')
    axs1[-1].set_xlabel('Step Number')
    fig1.tight_layout()
    fig1.savefig('sn_trace.png')
    plt.close(fig1)

    # Corner plot
    fig2, _ = plt.subplots(n_params, n_params, figsize=(20, 20))
    corner.corner(
        samps,
        labels=param_labels,
        fig=fig2,
        truths=truths
    )
    fig2.savefig('sn_corner.png')
    plt.close(fig2)

    # True vs predicted
    plot_kwargs = {'fmt': '.', 'markersize': 3, 'zorder': 10,
                   'capsize': 1, 'elinewidth': 1, 'alpha': 1}
    for var in ['stretch', 'c', 'm']:
        vname = var + '_true'
        if vname not in samples:
            continue

        xtrue = true_vars[var]
        # 16th / 50th / 84th percentiles give the median and the asymmetric error bars
        lower, xpred_median, upper = np.percentile(np.asarray(samples[vname]), [16, 50, 84], axis=0)
        xpred_plus = upper - xpred_median
        xpred_minus = xpred_median - lower

        fig3, axs3 = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
        axs3[0].errorbar(xtrue, xpred_median, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
        axs3[1].errorbar(xtrue, xpred_median - xtrue, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
        axs3[1].set_xlabel('True')
        axs3[0].set_ylabel('Predicted')
        axs3[1].set_ylabel('Predicted - True')
        # Draw the one-to-one line without letting it change the automatic axis limits
        xlim = axs3[0].get_xlim()
        ylim = axs3[0].get_ylim()
        axs3[0].plot(xlim, xlim, color='k', zorder=0)
        axs3[0].set_xlim(xlim)
        axs3[0].set_ylim(ylim)
        axs3[1].axhline(0, color='k', zorder=0)
        fig3.suptitle(var)
        fig3.align_labels()
        fig3.tight_layout()
        fig3.savefig(f'sn_true_predicted_{var}.png')
        plt.close(fig3)

    return
generateMBData(RA, Dec, czCMB, L, N, R_lim, Nsig, Nint_points, sigma_v, frac_sigma_r) + initial = { + 'a_tripp': a_tripp, + 'b_tripp': b_tripp, + 'M_SN': M_SN, + 'sigma_SN': sigma_SN, + 'sigma_v': sigma_v, + 'hyper_mean_stretch': hyper_stretch_mu, + 'hyper_sigma_stretch': hyper_stretch_sigma, + 'hyper_mean_c': hyper_c_mu, + 'hyper_sigma_c': hyper_c_sigma, + 'hyper_mean_m': np.median(m_obs), + 'hyper_sigma_m': (np.percentile(m_obs, 84) - np.percentile(m_obs, 16)) / 2, + } + prior = { + 'alpha': [0.5, 4.5], + 'a_tripp': [0.01, 0.2], + 'b_tripp': [2.5, 4.5], + 'M_SN': [-19.5, -17.5], + 'hyper_mean_stretch': [hyper_stretch_mu - hyper_stretch_sigma, hyper_stretch_mu + hyper_stretch_sigma], + 'hyper_mean_c':[hyper_c_mu - hyper_c_sigma, hyper_c_mu + hyper_c_sigma], + 'hyper_mean_m':[initial['hyper_mean_m'] - initial['hyper_sigma_m'], initial['hyper_mean_m'] + initial['hyper_sigma_m']], + 'sigma_v': [10, 3000], + } + # Test likelihood loglike = likelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk, dens, vel, cpar.omega_m, cpar.h, L, xmin, interp_order, bias_epsilon, @@ -525,9 +758,32 @@ def main(): czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos) + # Run a MCMC + mcmc = run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, interp_order, bias_epsilon, + czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos,) + param_labels = ['alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_SN', 'sigma_v', + 'hyper_mean_m', 'hyper_sigma_m', 'hyper_mean_stretch', 'hyper_sigma_stretch', + 'hyper_mean_c', 'hyper_sigma_c', 'hyper_corr_mx', 'hyper_corr_mc', 'hyper_corr_xc', + 'vbulk_x', 'vbulk_y', 'vbulk_z'] + truths = [alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, + None, None, hyper_stretch_mu, hyper_stretch_sigma, + hyper_c_mu, hyper_c_sigma, None, None, None, + vbulk[0], vbulk[1], vbulk[2]] + true_vars = {'m':m_true, 'stretch':stretch_true, 'c': c_true} + process_mcmc_run(mcmc, 
param_labels, truths, true_vars) + return if __name__ == "__main__": main() - \ No newline at end of file + + +""" +TO DO + +- Fix SN inference - poor sampling and Tripp variables not constrained +- Deal with case where sigma_eta and sigma_m could be floats vs arrays +- Add in selection cuts for the supernovae + +""" \ No newline at end of file diff --git a/tests/sn_trace.png b/tests/sn_trace.png new file mode 100644 index 0000000..3ee4fde Binary files /dev/null and b/tests/sn_trace.png differ diff --git a/tests/sn_true_predicted_c.png b/tests/sn_true_predicted_c.png new file mode 100644 index 0000000..a0d1648 Binary files /dev/null and b/tests/sn_true_predicted_c.png differ diff --git a/tests/sn_true_predicted_m.png b/tests/sn_true_predicted_m.png new file mode 100644 index 0000000..af3c5fa Binary files /dev/null and b/tests/sn_true_predicted_m.png differ diff --git a/tests/sn_true_predicted_stretch.png b/tests/sn_true_predicted_stretch.png new file mode 100644 index 0000000..4c90964 Binary files /dev/null and b/tests/sn_true_predicted_stretch.png differ diff --git a/tests/tfr_corner.png b/tests/tfr_corner.png new file mode 100644 index 0000000..287d7b5 Binary files /dev/null and b/tests/tfr_corner.png differ diff --git a/tests/tfr_inference.py b/tests/tfr_inference.py index 1a76f96..9eeec8b 100644 --- a/tests/tfr_inference.py +++ b/tests/tfr_inference.py @@ -666,7 +666,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, plt.axhline(orig_ll, ls='--', color='k') plt.xlabel(name) plt.ylabel('Negative log-likelihood') - plt.savefig(f'likelihood_scan_{name}.png') + plt.savefig(f'tfr_likelihood_scan_{name}.png') fig = plt.gcf() plt.clf() plt.close(fig) @@ -685,7 +685,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, plt.axhline(orig_ll, ls='--', color='k') plt.xlabel('mthresh') plt.ylabel('Negative log-likelihood') - plt.savefig(f'likelihood_scan_mthresh.png') + 
plt.savefig(f'tfr_likelihood_scan_mthresh.png') fig = plt.gcf() plt.clf() plt.close(fig) @@ -763,22 +763,20 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, vbulk_z = numpyro.sample("vbulk_z", dist.Normal(0, sigma_bulk / jnp.sqrt(3))) # Evaluate the likelihood - numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z), obs=jnp.array([m_obs, eta_obs])) + numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z), obs=jnp.array([m_obs, eta_obs])) class TFRLikelihood(dist.Distribution): support = dist.constraints.real - def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z): - self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.hyper_mean_eta, self.hyper_sigma_eta, self.m_true, self.eta_true, self.vbulk_x, self.vbulk_y, self.vbulk_z = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z) + def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z): + self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.m_true, self.eta_true, self.vbulk_x, self.vbulk_y, self.vbulk_z = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z) batch_shape = lax.broadcast_shapes( jnp.shape(alpha), jnp.shape(a_TFR), jnp.shape(b_TFR), jnp.shape(sigma_TFR), jnp.shape(sigma_v), - jnp.shape(hyper_mean_eta), - jnp.shape(hyper_sigma_eta), jnp.shape(m_true), jnp.shape(eta_true), jnp.shape(vbulk_x), @@ -851,7 +849,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars): axs1[i].axhline(truths[i], color='k') axs1[-1].set_xlabel('Step Number') fig1.tight_layout() - fig1.savefig('trace.png') + 
fig1.savefig('tfr_trace.png') # Corner plot fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(20,20)) @@ -861,7 +859,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars): fig=fig2, truths=truths ) - fig2.savefig('corner.png') + fig2.savefig('tfr_corner.png') # True vs predicted for var in ['eta', 'm']: @@ -889,7 +887,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars): fig3.suptitle(var) fig3.align_labels() fig3.tight_layout() - fig3.savefig(f'true_predicted_{var}.png') + fig3.savefig(f'tfr_true_predicted_{var}.png') return diff --git a/tests/likelihood_scan_a_TFR.png b/tests/tfr_likelihood_scan_a_TFR.png similarity index 100% rename from tests/likelihood_scan_a_TFR.png rename to tests/tfr_likelihood_scan_a_TFR.png diff --git a/tests/likelihood_scan_alpha.png b/tests/tfr_likelihood_scan_alpha.png similarity index 100% rename from tests/likelihood_scan_alpha.png rename to tests/tfr_likelihood_scan_alpha.png diff --git a/tests/likelihood_scan_b_TFR.png b/tests/tfr_likelihood_scan_b_TFR.png similarity index 100% rename from tests/likelihood_scan_b_TFR.png rename to tests/tfr_likelihood_scan_b_TFR.png diff --git a/tests/likelihood_scan_mthresh.png b/tests/tfr_likelihood_scan_mthresh.png similarity index 100% rename from tests/likelihood_scan_mthresh.png rename to tests/tfr_likelihood_scan_mthresh.png diff --git a/tests/likelihood_scan_sigma_TFR.png b/tests/tfr_likelihood_scan_sigma_TFR.png similarity index 100% rename from tests/likelihood_scan_sigma_TFR.png rename to tests/tfr_likelihood_scan_sigma_TFR.png diff --git a/tests/likelihood_scan_sigma_v.png b/tests/tfr_likelihood_scan_sigma_v.png similarity index 100% rename from tests/likelihood_scan_sigma_v.png rename to tests/tfr_likelihood_scan_sigma_v.png diff --git a/tests/tfr_trace.png b/tests/tfr_trace.png new file mode 100644 index 0000000..17a126c Binary files /dev/null and b/tests/tfr_trace.png differ diff --git a/tests/tfr_true_predicted_eta.png 
b/tests/tfr_true_predicted_eta.png new file mode 100644 index 0000000..7797e2a Binary files /dev/null and b/tests/tfr_true_predicted_eta.png differ diff --git a/tests/tfr_true_predicted_m.png b/tests/tfr_true_predicted_m.png new file mode 100644 index 0000000..217325e Binary files /dev/null and b/tests/tfr_true_predicted_m.png differ