First attempt at SN inference test
BIN  tests/sn_corner.png (new file, 1.4 MiB)
@@ -7,12 +7,17 @@ import astropy.units as apu
import jax.numpy as jnp
import jax

import corner
import matplotlib.pyplot as plt

import borg_velocity.poisson_process as poisson_process
import borg_velocity.projection as projection
import borg_velocity.utils as utils

import numpyro
import numpyro.distributions as dist
from jax import lax, random

from tfr_inference import get_fields, generateMBData

# Output stream management

@@ -464,7 +469,218 @@ def test_likelihood_scan(prior, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v
    return


def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, interp_order, bias_epsilon,
             czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos,):
    """
    Run MCMC over the model parameters

    Args:
        - num_warmup (int): Number of warmup steps to take in the MCMC
        - num_samples (int): Number of samples to take in the MCMC
        - prior (dict): Lower and upper bounds of the uniform priors on the sampled parameters
        - initial (dict): Initial parameter values used to start the MCMC
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent raising zero to a power
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch parameters of the supernovae (shape = (Nt,))
        - c_obs (np.ndarray): Observed colours of the supernovae (shape = (Nt,))
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float): Uncertainty on the stretch measurements
        - sigma_c (float): Uncertainty on the colour measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run
    """

    Nt = stretch_obs.shape[0]
    omega_m = cpar.omega_m
    h = cpar.h
    sigma_bulk = utils.get_sigma_bulk(L, cpar)
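    # Bulk-flow scale for a box of side L; used below as the prior width on each
    # Cartesian velocity component (divided by sqrt(3))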

    def sn_model():

        alpha = numpyro.sample("alpha", dist.Uniform(*prior['alpha']))
        a_tripp = numpyro.sample("a_tripp", dist.Uniform(*prior['a_tripp']))
        b_tripp = numpyro.sample("b_tripp", dist.Uniform(*prior['b_tripp']))
        M_SN = numpyro.sample("M_SN", dist.Uniform(*prior['M_SN']))
        sigma_SN = numpyro.sample("sigma_SN", dist.HalfCauchy(1.0))
        sigma_v = numpyro.sample("sigma_v", dist.Uniform(*prior['sigma_v']))

        hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m']))
        hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0))  # Equivalent to 1/sigma prior
        hyper_mean_stretch = numpyro.sample("hyper_mean_stretch", dist.Uniform(*prior['hyper_mean_stretch']))
        hyper_sigma_stretch = numpyro.sample("hyper_sigma_stretch", dist.HalfCauchy(1.0))  # Equivalent to 1/sigma prior
        hyper_mean_c = numpyro.sample("hyper_mean_c", dist.Uniform(*prior['hyper_mean_c']))
        hyper_sigma_c = numpyro.sample("hyper_sigma_c", dist.HalfCauchy(1.0))  # Equivalent to 1/sigma prior

        # Sample correlation matrix using LKJ prior
        L_corr = numpyro.sample("L_corr", dist.LKJCholesky(3, concentration=1.0))  # Cholesky factor of correlation matrix
        corr_matrix = L_corr @ L_corr.T  # Convert to full correlation matrix

        # Construct full covariance matrix: Σ = D * Corr * D
        hyper_mean = jnp.array([hyper_mean_m, hyper_mean_stretch, hyper_mean_c])
        hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_stretch, hyper_sigma_c])
        hyper_cov = jnp.diag(hyper_sigma) @ corr_matrix @ jnp.diag(hyper_sigma)
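        # Elementwise this is hyper_cov[i, j] = hyper_sigma[i] * corr_matrix[i, j] * hyper_sigma[j]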

        # Sample the true (latent) m, stretch and c for each supernova
        x = numpyro.sample("true_vars", dist.MultivariateNormal(hyper_mean, hyper_cov), sample_shape=(Nt,))
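        # x has shape (Nt, 3); the columns are (m, stretch, c)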
        m_true = numpyro.deterministic("m_true", x[:, 0])
        stretch_true = numpyro.deterministic("stretch_true", x[:, 1])
        c_true = numpyro.deterministic("c_true", x[:, 2])

        # Sample bulk velocity
        vbulk_x = numpyro.sample("vbulk_x", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_y = numpyro.sample("vbulk_y", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_z = numpyro.sample("vbulk_z", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))

        # Evaluate the likelihood
        numpyro.sample("obs", SNLikelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z), obs=jnp.array([m_obs, stretch_obs, c_obs]))

    class SNLikelihood(dist.Distribution):
        support = dist.constraints.real

        def __init__(self, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z):
            self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v, self.m_true, self.stretch_true, self.c_true, self.vbulk_x, self.vbulk_y, self.vbulk_z = dist.util.promote_shapes(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z)
            batch_shape = lax.broadcast_shapes(
                jnp.shape(alpha),
                jnp.shape(a_tripp),
                jnp.shape(b_tripp),
                jnp.shape(M_SN),
                jnp.shape(sigma_SN),
                jnp.shape(sigma_v),
                jnp.shape(m_true),
                jnp.shape(stretch_true),
                jnp.shape(c_true),
                jnp.shape(vbulk_x),
                jnp.shape(vbulk_y),
                jnp.shape(vbulk_z),
            )
            super(SNLikelihood, self).__init__(batch_shape=batch_shape)
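            # The batch_shape passed to dist.Distribution is the broadcast of all parameter
            # shapes, so scalar calibration parameters and per-tracer arrays can be mixed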

        def sample(self, key, sample_shape=()):
            raise NotImplementedError

        def log_prob(self, value):
            vbulk = jnp.array([self.vbulk_x, self.vbulk_y, self.vbulk_z])
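            # dens, vel, omega_m, h, L, xmin, czCMB, the observed SN quantities and MB_pos
            # are not attributes: they are captured from the enclosing run_mcmc scope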
            loglike = likelihood(self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v,
                                 self.m_true, self.stretch_true, self.c_true, vbulk,
                                 dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                                 czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)
            return loglike

    rng_key = random.PRNGKey(6)
    rng_key, rng_key_ = random.split(rng_key)
    values = initial
    values['true_vars'] = jnp.array([m_obs, stretch_obs, c_obs]).T
    values['L_corr'] = jnp.identity(3)
    values['vbulk_x'] = 0.
    values['vbulk_y'] = 0.
    values['vbulk_z'] = 0.
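    # Start the chain from the observed values for the latent SN properties, an identity
    # correlation matrix and zero bulk flow, on top of the initial values supplied by the caller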
    myprint('Preparing MCMC kernel')
    kernel = numpyro.infer.NUTS(sn_model,
                                init_strategy=numpyro.infer.initialization.init_to_value(values=values)
                                )
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    myprint('Running MCMC')
    mcmc.run(rng_key_)
    mcmc.print_summary()

    return mcmc


def process_mcmc_run(mcmc, param_labels, truths, true_vars):
    """
    Make summary plots from the MCMC and save these to file

    Args:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run
        - param_labels (list[str]): Names of the parameters to plot
        - truths (list[float]): True values of the parameters to plot. If unknown, then entry is None
        - true_vars (dict): True values of the observables to compare against inferred ones
    """

    # Convert samples into a single array
    samples = mcmc.get_samples()

    samps = jnp.empty((len(samples[param_labels[0]]), len(param_labels)))
    for i, p in enumerate(param_labels):
        if p.startswith('hyper_corr'):
            L_corr = samples['L_corr']
            corr_matrix = jnp.matmul(L_corr, jnp.transpose(L_corr, (0, 2, 1)))
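            # samples['L_corr'] has shape (Nsamp, 3, 3); this rebuilds the full correlation
            # matrix for every sample so the off-diagonal elements can be read off below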
            if p == 'hyper_corr_mx':
                samps = samps.at[:,i].set(corr_matrix[:,0,1])
            elif p == 'hyper_corr_mc':
                samps = samps.at[:,i].set(corr_matrix[:,0,2])
            elif p == 'hyper_corr_xc':
                samps = samps.at[:,i].set(corr_matrix[:,1,2])
            else:
                raise NotImplementedError
        else:
            samps = samps.at[:,i].set(samples[p])

    # Trace plot of non-redshift quantities
    fig1, axs1 = plt.subplots(samps.shape[1], 1, figsize=(6,3*samps.shape[1]), sharex=True)
    axs1 = np.atleast_1d(axs1)
    for i in range(samps.shape[1]):
        axs1[i].plot(samps[:,i])
        axs1[i].set_ylabel(param_labels[i])
        if truths[i] is not None:
            axs1[i].axhline(truths[i], color='k')
    axs1[-1].set_xlabel('Step Number')
    fig1.tight_layout()
    fig1.savefig('sn_trace.png')

    # Corner plot
    fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(20,20))
    corner.corner(
        np.array(samps),
        labels=param_labels,
        fig=fig2,
        truths=truths
    )
    fig2.savefig('sn_corner.png')

    # True vs predicted
    for var in ['stretch', 'c', 'm']:
        vname = var + '_true'
        if vname in samples.keys():
            xtrue = true_vars[var]
            xpred_median = np.median(samples[vname], axis=0)
            xpred_plus = np.percentile(samples[vname], 84, axis=0) - xpred_median
            xpred_minus = xpred_median - np.percentile(samples[vname], 16, axis=0)

            fig3, axs3 = plt.subplots(2, 1, figsize=(10,8), sharex=True)
            plot_kwargs = {'fmt':'.', 'markersize':3, 'zorder':10,
                           'capsize':1, 'elinewidth':1, 'alpha':1}
            axs3[0].errorbar(xtrue, xpred_median, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].errorbar(xtrue, xpred_median - xtrue, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].set_xlabel('True')
            axs3[0].set_ylabel('Predicted')
            axs3[1].set_ylabel('Predicted - True')
            xlim = axs3[0].get_xlim()
            ylim = axs3[0].get_ylim()
            axs3[0].plot(xlim, xlim, color='k', zorder=0)
            axs3[0].set_xlim(xlim)
            axs3[0].set_ylim(ylim)
            axs3[1].axhline(0, color='k', zorder=0)
            fig3.suptitle(var)
            fig3.align_labels()
            fig3.tight_layout()
            fig3.savefig(f'sn_true_predicted_{var}.png')

    return


def main():

@@ -493,15 +709,8 @@ def main():
    M_SN = - 18.558
    sigma_SN = 0.082

-    prior = {
-        'alpha': [0.5, 4.5],
-        'a_tripp': [0.01, 0.2],
-        'b_tripp': [2.5, 4.5],
-        'M_SN': [-19.5, -17.5],
-        'hyper_mean_stretch': [hyper_stretch_mu - hyper_stretch_sigma, hyper_stretch_mu + hyper_stretch_sigma],
-        'hyper_mean_c':[hyper_c_mu - hyper_c_sigma, hyper_c_mu + hyper_c_sigma],
-        'sigma_v': [10, 3000],
-    }
+    num_warmup = 1000
+    num_samples = 2000

    # Make mock
    np.random.seed(123)

@@ -513,6 +722,30 @@ def main():
        sigma_v, interp_order=interp_order, bias_epsilon=bias_epsilon)
    MB_pos = generateMBData(RA, Dec, czCMB, L, N, R_lim, Nsig, Nint_points, sigma_v, frac_sigma_r)

    initial = {
        'a_tripp': a_tripp,
        'b_tripp': b_tripp,
        'M_SN': M_SN,
        'sigma_SN': sigma_SN,
        'sigma_v': sigma_v,
        'hyper_mean_stretch': hyper_stretch_mu,
        'hyper_sigma_stretch': hyper_stretch_sigma,
        'hyper_mean_c': hyper_c_mu,
        'hyper_sigma_c': hyper_c_sigma,
        'hyper_mean_m': np.median(m_obs),
        'hyper_sigma_m': (np.percentile(m_obs, 84) - np.percentile(m_obs, 16)) / 2,
    }
    prior = {
        'alpha': [0.5, 4.5],
        'a_tripp': [0.01, 0.2],
        'b_tripp': [2.5, 4.5],
        'M_SN': [-19.5, -17.5],
        'hyper_mean_stretch': [hyper_stretch_mu - hyper_stretch_sigma, hyper_stretch_mu + hyper_stretch_sigma],
        'hyper_mean_c': [hyper_c_mu - hyper_c_sigma, hyper_c_mu + hyper_c_sigma],
        'hyper_mean_m': [initial['hyper_mean_m'] - initial['hyper_sigma_m'], initial['hyper_mean_m'] + initial['hyper_sigma_m']],
        'sigma_v': [10, 3000],
    }

    # Test likelihood
    loglike = likelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk,
                         dens, vel, cpar.omega_m, cpar.h, L, xmin, interp_order, bias_epsilon,

@@ -525,9 +758,32 @@ def main():
                         czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)

    # Run a MCMC
    mcmc = run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, interp_order, bias_epsilon,
                    czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos,)
    param_labels = ['alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_SN', 'sigma_v',
                    'hyper_mean_m', 'hyper_sigma_m', 'hyper_mean_stretch', 'hyper_sigma_stretch',
                    'hyper_mean_c', 'hyper_sigma_c', 'hyper_corr_mx', 'hyper_corr_mc', 'hyper_corr_xc',
                    'vbulk_x', 'vbulk_y', 'vbulk_z']
    truths = [alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
              None, None, hyper_stretch_mu, hyper_stretch_sigma,
              hyper_c_mu, hyper_c_sigma, None, None, None,
              vbulk[0], vbulk[1], vbulk[2]]
    true_vars = {'m': m_true, 'stretch': stretch_true, 'c': c_true}
    process_mcmc_run(mcmc, param_labels, truths, true_vars)

    return


if __name__ == "__main__":
    main()


"""
TO DO

- Fix SN inference - poor sampling and Tripp variables not constrained
- Deal with case where sigma_eta and sigma_m could be floats vs arrays (a sketch of one option follows below)
- Add in selection cuts for the supernovae

"""
BIN  tests/sn_trace.png (new file, 518 KiB)
BIN  tests/sn_true_predicted_c.png (new file, 52 KiB)
BIN  tests/sn_true_predicted_m.png (new file, 44 KiB)
BIN  tests/sn_true_predicted_stretch.png (new file, 47 KiB)
BIN  tests/tfr_corner.png (new file, 1.2 MiB)

@@ -666,7 +666,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true,
    plt.axhline(orig_ll, ls='--', color='k')
    plt.xlabel(name)
    plt.ylabel('Negative log-likelihood')
-    plt.savefig(f'likelihood_scan_{name}.png')
+    plt.savefig(f'tfr_likelihood_scan_{name}.png')
    fig = plt.gcf()
    plt.clf()
    plt.close(fig)

@@ -685,7 +685,7 @@ def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true,
    plt.axhline(orig_ll, ls='--', color='k')
    plt.xlabel('mthresh')
    plt.ylabel('Negative log-likelihood')
-    plt.savefig(f'likelihood_scan_mthresh.png')
+    plt.savefig(f'tfr_likelihood_scan_mthresh.png')
    fig = plt.gcf()
    plt.clf()
    plt.close(fig)

@@ -763,22 +763,20 @@ def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin,
        vbulk_z = numpyro.sample("vbulk_z", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))

        # Evaluate the likelihood
-        numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z), obs=jnp.array([m_obs, eta_obs]))
+        numpyro.sample("obs", TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z), obs=jnp.array([m_obs, eta_obs]))


    class TFRLikelihood(dist.Distribution):
        support = dist.constraints.real

-        def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z):
+        def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z):
-            self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.hyper_mean_eta, self.hyper_sigma_eta, self.m_true, self.eta_true, self.vbulk_x, self.vbulk_y, self.vbulk_z = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, hyper_mean_eta, hyper_sigma_eta, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z)
+            self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.m_true, self.eta_true, self.vbulk_x, self.vbulk_y, self.vbulk_z = dist.util.promote_shapes(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, vbulk_x, vbulk_y, vbulk_z)
            batch_shape = lax.broadcast_shapes(
                jnp.shape(alpha),
                jnp.shape(a_TFR),
                jnp.shape(b_TFR),
                jnp.shape(sigma_TFR),
                jnp.shape(sigma_v),
-                jnp.shape(hyper_mean_eta),
-                jnp.shape(hyper_sigma_eta),
                jnp.shape(m_true),
                jnp.shape(eta_true),
                jnp.shape(vbulk_x),

@@ -851,7 +849,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars):
            axs1[i].axhline(truths[i], color='k')
    axs1[-1].set_xlabel('Step Number')
    fig1.tight_layout()
-    fig1.savefig('trace.png')
+    fig1.savefig('tfr_trace.png')

    # Corner plot
    fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(20,20))

@@ -861,7 +859,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars):
        fig=fig2,
        truths=truths
    )
-    fig2.savefig('corner.png')
+    fig2.savefig('tfr_corner.png')

    # True vs predicted
    for var in ['eta', 'm']:

@@ -889,7 +887,7 @@ def process_mcmc_run(mcmc, param_labels, truths, true_vars):
            fig3.suptitle(var)
            fig3.align_labels()
            fig3.tight_layout()
-            fig3.savefig(f'true_predicted_{var}.png')
+            fig3.savefig(f'tfr_true_predicted_{var}.png')

    return

(6 existing test images re-rendered; sizes unchanged at 17-18 KiB each)
BIN  tests/tfr_trace.png (new file, 387 KiB)
BIN  tests/tfr_true_predicted_eta.png (new file, 46 KiB)
BIN  tests/tfr_true_predicted_m.png (new file, 46 KiB)