borg_velocity/tests/sn_inference.py
2025-02-10 21:30:27 +01:00

855 lines
No EOL
37 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import aquila_borg as borg
import pandas as pd
import linecache
import numpy as np
from astropy.coordinates import SkyCoord
import astropy.units as apu
import jax.numpy as jnp
import jax
import corner
import matplotlib.pyplot as plt
import borg_velocity.poisson_process as poisson_process
import borg_velocity.projection as projection
import borg_velocity.utils as utils
import numpyro
import numpyro.distributions as dist
from jax import lax, random
from tfr_inference import get_fields, generateMBData
# Output stream management
cons = borg.console()


def myprint(x):
    """
    Print a message through the BORG console.

    Args:
        - x (object): Message to print. Strings are printed as they are; anything else
          is printed via its repr().

    Returns:
        - Whatever cons.print_std returns
    """
    # isinstance (rather than type(x) == str) also treats str subclasses as text
    return cons.print_std(x if isinstance(x, str) else repr(x))
def create_mock(Nt, L, xmin, cpar, dens, vel, Rmax, alpha,
                a_tripp, b_tripp, M_SN, sigma_SN, sigma_m, sigma_stretch, sigma_c,
                hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma,
                sigma_v, interp_order=1, bias_epsilon=1e-7):
    """
    Create mock supernova (Tripp relation) catalogue from a density and velocity field

    Args:
        - Nt (int): Number of tracers to produce
        - L (float): Box length (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - Rmax (float): Maximum allowed comoving radius of a tracer (Mpc/h)
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float): Uncertainty on the stretch measurements
        - sigma_c (float): Uncertainty on the colour measurements
        - hyper_stretch_mu (float): Mean of Gaussian hyper prior for the true stretch values
        - hyper_stretch_sigma (float): Std of Gaussian hyper prior for the true stretch values
        - hyper_c_mu (float): Mean of Gaussian hyper prior for the true colour values
        - hyper_c_sigma (float): Std of Gaussian hyper prior for the true colour values
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - interp_order (int, default=1): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float, default=1e-7): Small number to add to 1 + delta to prevent 0^#

    Returns:
        - all_RA (np.ndarray): Right Ascension (degrees) of the tracers (shape = (Nt,))
        - all_Dec (np.ndarray): Declination (degrees) of the tracers (shape = (Nt,))
        - czCMB (array): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - all_mtrue (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - all_stretchtrue (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - all_ctrue (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - all_mobs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - all_stretchobs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
        - all_cobs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
        - all_xtrue (np.ndarray): True comoving coordinates of the tracers (Mpc/h) (shape = (3, Nt))
        - vbulk (np.ndarray): The bulk velocity of the box (km/s) (shape = (3,))
    """

    # Pre-allocate the output arrays; they are filled in batches by the loop below
    all_xtrue = np.empty((3, Nt))
    all_mtrue = np.empty(Nt)
    all_stretchtrue = np.empty(Nt)
    all_ctrue = np.empty(Nt)
    all_mobs = np.empty(Nt)
    all_stretchobs = np.empty(Nt)
    all_cobs = np.empty(Nt)
    all_RA = np.empty(Nt)
    all_Dec = np.empty(Nt)

    # Counter for accepted positions
    accepted_count = 0

    # Bias model
    phi = (1. + dens + bias_epsilon) ** alpha

    # Only use centre of box: pick the grid edges closest to -Rmax and +Rmax
    # NOTE(review): this restricts tracers to a cube of half-width ~Rmax; no cut on the
    # radius itself is applied below, so tracers near the cube corners can have r > Rmax.
    x = np.linspace(xmin, xmin + L, dens.shape[0]+1)
    i0 = np.argmin(np.abs(x + Rmax))
    i1 = np.argmin(np.abs(x - Rmax))
    L_small = x[i1] - x[i0]
    xmin_small = x[i0]
    phi_small = phi[i0:i1, i0:i1, i0:i1]

    # Loop until we have Nt valid positions
    while accepted_count < Nt:

        # Generate positions (comoving)
        xtrue = poisson_process.sample_3d(phi_small, Nt, L_small, (xmin_small, xmin_small, xmin_small))

        # Convert to RA, Dec, Distance (comoving)
        rtrue = np.sqrt(np.sum(xtrue** 2, axis=0))   # Mpc/h
        c = SkyCoord(x=xtrue[0], y=xtrue[1], z=xtrue[2], representation_type='cartesian')
        RA = c.spherical.lon.degree
        Dec = c.spherical.lat.degree

        # Compute cosmological redshift
        zcosmo = utils.z_cos(rtrue, cpar.omega_m)

        # Compute luminosity distance
        # rtrue is in Mpc/h, so dividing by h gives Mpc as required by the "+ 25" below
        dL = (1 + zcosmo) * rtrue / cpar.h    # Mpc

        # Compute true distance modulus
        mutrue = 5 * np.log10(dL) + 25

        # Sample true stretch and colour (c) from its prior
        stretchtrue = hyper_stretch_mu + hyper_stretch_sigma * np.random.randn(Nt)
        ctrue = hyper_c_mu + hyper_c_sigma * np.random.randn(Nt)

        # Obtain muSN from mutrue using the intrinsic scatter
        muSN = mutrue + sigma_SN * np.random.randn(Nt)

        # Obtain apparent magnitude from the Tripp relation
        mtrue = muSN - (a_tripp * stretchtrue - b_tripp * ctrue) + M_SN

        # Scatter true apparent magnitudes, stretches and colours to get observed values
        mobs = mtrue + sigma_m * np.random.randn(Nt)
        stretchobs = stretchtrue + sigma_stretch * np.random.randn(Nt)
        cobs = ctrue + sigma_c * np.random.randn(Nt)

        # Apply apparent magnitude cut (currently a no-op: every tracer is kept)
        m = np.ones(mobs.shape, dtype=bool)
        mtrue = mtrue[m]
        stretchtrue = stretchtrue[m]
        ctrue = ctrue[m]
        mobs = mobs[m]
        stretchobs = stretchobs[m]
        cobs = cobs[m]
        xtrue = xtrue[:,m]
        RA = RA[m]
        Dec = Dec[m]

        # Calculate how many valid positions we need to reach Nt
        remaining_needed = Nt - accepted_count
        selected_count = min(xtrue.shape[1], remaining_needed)

        # Append only the needed number of valid positions
        imin = accepted_count
        imax = accepted_count + selected_count
        all_xtrue[:,imin:imax] = xtrue[:,:selected_count]
        all_mtrue[imin:imax] = mtrue[:selected_count]
        all_stretchtrue[imin:imax] = stretchtrue[:selected_count]
        all_ctrue[imin:imax] = ctrue[:selected_count]
        all_mobs[imin:imax] = mobs[:selected_count]
        all_stretchobs[imin:imax] = stretchobs[:selected_count]
        all_cobs[imin:imax] = cobs[:selected_count]
        all_RA[imin:imax] = RA[:selected_count]
        all_Dec[imin:imax] = Dec[:selected_count]

        # Update the accepted count
        accepted_count += selected_count

        myprint(f'\tMade {accepted_count} of {Nt}')

    # Obtain a bulk velocity: random direction, Gaussian amplitude
    vhat = np.random.randn(3)
    vhat = vhat / np.linalg.norm(vhat)
    vbulk = np.random.randn() * utils.get_sigma_bulk(L, cpar)
    vbulk = vhat * vbulk

    # Get the radial component of the peculiar velocity at the positions of the objects
    myprint('Obtaining peculiar velocities')
    tracer_vel = projection.interp_field(
        vel,
        np.expand_dims(all_xtrue, axis=2),
        L,
        np.array([xmin, xmin, xmin]),
        interp_order
    )    # km/s
    myprint('Adding bulk velocity')
    tracer_vel = tracer_vel + vbulk[:,None,None]
    myprint('Radial projection')
    vr_true = np.squeeze(projection.project_radial(
        tracer_vel,
        np.expand_dims(all_xtrue, axis=2),
        np.zeros(3,)
    ))   # km/s

    # Recompute cosmological redshift (for the full accepted sample)
    rtrue = jnp.sqrt(jnp.sum(all_xtrue ** 2, axis=0))
    zcosmo = utils.z_cos(rtrue, cpar.omega_m)

    # Obtain total redshift
    vr_noised = vr_true + sigma_v * np.random.randn(Nt)
    czCMB = ((1 + zcosmo) * (1 + vr_noised / utils.speed_of_light) - 1) * utils.speed_of_light

    return all_RA, all_Dec, czCMB, all_mtrue, all_stretchtrue, all_ctrue, all_mobs, all_stretchobs, all_cobs, all_xtrue, vbulk
def estimate_data_parameters(fname='/data101/bartlett/fsigma8/PV_data/Foundation_DR1/Foundation_DR1.FITRES.TEXT'):
    """
    Using Foundation DR1, estimate some parameters to use in mock generation.

    The column names are read from line 6 of the file (its first token is replaced by 'SN')
    and the first 7 lines are skipped when reading the table.

    Args:
        - fname (str): Path to the whitespace-separated FITRES text file. Defaults to the
          Foundation DR1 file used originally. The table must contain the columns
          'mBERR', 'x1', 'x1ERR', 'c' and 'cERR'.

    Returns:
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float): Uncertainty on the stretch measurements
        - sigma_c (float): Uncertainty on the colour measurements
        - hyper_stretch_mu (float): Estimate of the mean of Gaussian hyper prior for the true stretch values
        - hyper_stretch_sigma (float): Estimate of the std of Gaussian hyper prior for the true stretch values
        - hyper_c_mu (float): Estimate of the mean of Gaussian hyper prior for the true colour values
        - hyper_c_sigma (float): Estimate of the std of Gaussian hyper prior for the true colour values
    """

    def _half_central_width(values):
        # Half the 16th-84th percentile range
        return (np.percentile(values, 84) - np.percentile(values, 16)) / 2

    # Get header
    columns = ['SN'] + linecache.getline(fname, 6).strip().split()[1:]
    # Raw string: "\s" is not a valid escape sequence in a normal string literal
    df = pd.read_csv(fname, sep=r"\s+", skipinitialspace=True, skiprows=7, names=columns)

    # Population (hyper prior) parameters: median and half the central percentile range
    x1 = df['x1']
    hyper_stretch_mu = np.median(x1)
    hyper_stretch_sigma = _half_central_width(x1)

    c = df['c']
    hyper_c_mu = np.median(c)
    hyper_c_sigma = _half_central_width(c)

    # Typical measurement uncertainties: median of the quoted errors
    sigma_m = np.median(df['mBERR'])
    sigma_stretch = np.median(df['x1ERR'])
    sigma_c = np.median(df['cERR'])

    return sigma_m, sigma_stretch, sigma_c, hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma
def likelihood_vel(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk,
                   dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                   cz_obs, MB_pos):
    """
    Evaluate the terms in the likelihood from the velocity and malmquist bias

    For each tracer the redshift likelihood is integrated along the line of sight
    (trapezoidal rule over the integration points) against a normalised radial
    distribution p(r) built from r^2, the biased density and the Tripp relation.

    Args:
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - vbulk (np.ndarray): Bulk velocity of the box (km/s) (shape=(3,))
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    # Comoving radii of integration points (Mpc/h)
    r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))

    # p_r = r^2 n(r) N(mutrue; mutripp, sigma_SN)
    # Constant factors are dropped for numerical stability (they cancel in p_r / p_r_norm)
    number_density = projection.interp_field(
        dens,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    # Clip 1 + delta at zero before raising to the power alpha
    number_density = jax.nn.relu(1. + number_density)
    number_density = jnp.power(number_density + bias_epsilon, alpha)

    # Distance modulus at each integration point vs. the one implied by the Tripp relation
    zcosmo = utils.z_cos(r, omega_m)
    mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
    mutripp = m_true + a_tripp * stretch_true - b_tripp * c_true - M_SN
    d2 = ((mutrue - mutripp[:,None]) / sigma_SN) ** 2
    # Subtract the per-tracer minimum so the exponential cannot underflow everywhere;
    # the resulting constant factor cancels in p_r / p_r_norm
    d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
    p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
    p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)

    # Peculiar velocity term
    tracer_vel = projection.interp_field(
        vel,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    tracer_vel = tracer_vel + jnp.squeeze(vbulk)[...,None,None]
    tracer_vr = projection.project_radial(
        tracer_vel,
        MB_pos,
        jnp.zeros(3,)
    )
    cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
    d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2
    # Factor out the per-tracer minimum; it is added back in log space below
    scale = jnp.nanmin(d2, axis=1)
    d2 = d2 - jnp.expand_dims(scale, axis=1)

    # Integrate to get likelihood
    p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
    lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2)
    loglike = lkl_ind.sum()

    return loglike
def likelihood_stretch(stretch_true, stretch_obs, sigma_stretch):
    """
    Gaussian log-likelihood of the observed stretch values given the true ones

    Args:
        - stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,))
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    n_tracers = stretch_obs.shape[0]

    # Normalisation: a scalar uncertainty contributes the same term for every tracer,
    # whereas an array of uncertainties is summed over its entries
    if jnp.ndim(sigma_stretch) == 0:
        log_norm = n_tracers * 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2)
    else:
        log_norm = jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2))

    # Chi-squared of the residuals
    chi2 = jnp.sum((stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2)

    return - (0.5 * chi2 + log_norm)
def likelihood_c(c_true, c_obs, sigma_c):
    """
    Gaussian log-likelihood of the observed colours given the true ones

    Args:
        - c_true (np.ndarray): True colours of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,))
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    n_tracers = c_obs.shape[0]

    # Normalisation: a scalar uncertainty contributes the same term for every tracer,
    # whereas an array of uncertainties is summed over its entries
    if jnp.ndim(sigma_c) == 0:
        log_norm = n_tracers * 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2)
    else:
        log_norm = jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2))

    # Chi-squared of the residuals
    chi2 = jnp.sum((c_obs - c_true) ** 2 / sigma_c ** 2)

    return - (0.5 * chi2 + log_norm)
def likelihood_m(m_true, m_obs, sigma_m):
    """
    Gaussian log-likelihood of the observed apparent magnitudes given the true ones

    Args:
        - m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    n_tracers = m_obs.shape[0]

    # Normalisation: a scalar uncertainty contributes the same term for every tracer,
    # whereas an array of uncertainties is summed over its entries
    if jnp.ndim(sigma_m) == 0:
        log_norm = n_tracers * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
    else:
        log_norm = jnp.sum(0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))

    # Chi-squared of the residuals
    chi2 = jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2)

    return - (0.5 * chi2 + log_norm)
def likelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk,
               dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
               cz_obs, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos):
    """
    Total log-likelihood for the SN sample: the redshift (velocity) term plus the
    measurement terms for stretch, colour and apparent magnitude.

    Args:
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - vbulk (np.ndarray): Bulk velocity of the box (km/s) (shape=(3,))
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    # Redshift term, marginalised over the line-of-sight distance
    velocity_term = likelihood_vel(
        alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
        m_true, stretch_true, c_true, vbulk,
        dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
        cz_obs, MB_pos,
    )

    # Gaussian measurement terms for each observable
    stretch_term = likelihood_stretch(stretch_true, stretch_obs, sigma_stretch)
    colour_term = likelihood_c(c_true, c_obs, sigma_c)
    magnitude_term = likelihood_m(m_true, m_obs, sigma_m)

    return velocity_term + stretch_term + colour_term + magnitude_term
def test_likelihood_scan(prior, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk,
                         dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                         czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos):
    """
    Scan the negative log-likelihood along each of the parameters
    [alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v] in turn (the others held at their
    input values) and save one plot per parameter as 'sn_likelihood_scan_<name>.png',
    so that the shape of the likelihood can be checked by eye.

    Args:
        - prior (dict): Upper and lower bounds for a uniform prior for the parameters.
          A parameter without an entry is scanned from 0.2 to 2 times its input value.
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - vbulk (np.ndarray): Bulk velocity of the box (km/s) (shape=(3,))
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)
    """

    par_names = ['alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_SN', 'sigma_v']
    fiducial = [alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v]

    # Everything that stays fixed during the scans
    fixed_args = (m_true, stretch_true, c_true, vbulk,
                  dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                  czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)

    def neg_loglike(values):
        return - likelihood(*values, *fixed_args)

    # Reference value at the input parameters
    orig_ll = neg_loglike(fiducial)

    for idx, name in enumerate(par_names):

        myprint(f'Scanning {name}')

        centre = fiducial[idx]
        if name in prior:
            grid = np.linspace(*prior[name], 20)
        else:
            grid = np.linspace(centre * 0.2, centre * 2.0, 20)

        # Vary only the current parameter
        scan = np.empty(grid.shape)
        for k, trial_value in enumerate(grid):
            trial = list(fiducial)
            trial[idx] = trial_value
            scan[k] = neg_loglike(trial)

        fig = plt.figure()
        plt.plot(grid, scan, '.')
        plt.axvline(centre, ls='--', color='k')
        plt.axhline(orig_ll, ls='--', color='k')
        plt.xlabel(name)
        plt.ylabel('Negative log-likelihood')
        plt.savefig(f'sn_likelihood_scan_{name}.png')
        plt.clf()
        plt.close(fig)

    return
def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin, interp_order, bias_epsilon,
             czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos,):
    """
    Run MCMC (NUTS) over the model parameters

    Args:
        - num_warmup (int): Number of warmup steps to take in the MCMC
        - num_samples (int): Number of samples to take in the MCMC
        - prior (dict): Upper and lower bounds for a uniform prior for the parameters.
          Must contain 'alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_v', 'hyper_mean_m',
          'hyper_mean_stretch' and 'hyper_mean_c'.
        - initial (dict): Initial values for the MCMC. This dictionary is not modified.
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run
    """

    Nt = stretch_obs.shape[0]
    omega_m = cpar.omega_m
    h = cpar.h
    sigma_bulk = utils.get_sigma_bulk(L, cpar)

    def sn_model():

        # Bias, Tripp relation and velocity-scatter parameters
        alpha = numpyro.sample("alpha", dist.Uniform(*prior['alpha']))
        a_tripp = numpyro.sample("a_tripp", dist.Uniform(*prior['a_tripp']))
        b_tripp = numpyro.sample("b_tripp", dist.Uniform(*prior['b_tripp']))
        M_SN = numpyro.sample("M_SN", dist.Uniform(*prior['M_SN']))
        sigma_SN = numpyro.sample("sigma_SN", dist.HalfCauchy(1.0))
        sigma_v = numpyro.sample("sigma_v", dist.Uniform(*prior['sigma_v']))

        # Hyper-parameters of the population of true (m, stretch, c).
        # The widths use a half-Cauchy prior with unit scale.
        # NOTE(review): this was previously described as "equivalent to a 1/sigma prior";
        # a half-Cauchy density is not proportional to 1/sigma - confirm the intent.
        hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m']))
        hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0))
        hyper_mean_stretch = numpyro.sample("hyper_mean_stretch", dist.Uniform(*prior['hyper_mean_stretch']))
        hyper_sigma_stretch = numpyro.sample("hyper_sigma_stretch", dist.HalfCauchy(1.0))
        hyper_mean_c = numpyro.sample("hyper_mean_c", dist.Uniform(*prior['hyper_mean_c']))
        hyper_sigma_c = numpyro.sample("hyper_sigma_c", dist.HalfCauchy(1.0))

        # Sample correlation matrix using LKJ prior
        L_corr = numpyro.sample("L_corr", dist.LKJCholesky(3, concentration=1.0))  # Cholesky factor of correlation matrix
        corr_matrix = L_corr @ L_corr.T  # Convert to full correlation matrix

        # Construct full covariance matrix: Sigma = D * Corr * D
        hyper_mean = jnp.array([hyper_mean_m, hyper_mean_stretch, hyper_mean_c])
        hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_stretch, hyper_sigma_c])
        hyper_cov = jnp.diag(hyper_sigma) @ corr_matrix @ jnp.diag(hyper_sigma)

        # Sample the true (m, stretch, c) of every tracer jointly
        x = numpyro.sample("true_vars", dist.MultivariateNormal(hyper_mean, hyper_cov), sample_shape=(Nt,))
        m_true = numpyro.deterministic("m_true", x[:, 0])
        stretch_true = numpyro.deterministic("stretch_true", x[:, 1])
        c_true = numpyro.deterministic("c_true", x[:, 2])

        # Sample bulk velocity (each component has std sigma_bulk / sqrt(3))
        vbulk_x = numpyro.sample("vbulk_x", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_y = numpyro.sample("vbulk_y", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))
        vbulk_z = numpyro.sample("vbulk_z", dist.Normal(0, sigma_bulk / jnp.sqrt(3)))

        # Evaluate the likelihood
        numpyro.sample(
            "obs",
            SNLikelihood(alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                         m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z),
            obs=jnp.array([m_obs, stretch_obs, c_obs]),
        )

    class SNLikelihood(dist.Distribution):
        """Distribution wrapper whose log_prob is the full SN log-likelihood."""

        support = dist.constraints.real

        def __init__(self, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z):
            (self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v,
             self.m_true, self.stretch_true, self.c_true,
             self.vbulk_x, self.vbulk_y, self.vbulk_z) = dist.util.promote_shapes(
                alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                m_true, stretch_true, c_true, vbulk_x, vbulk_y, vbulk_z)
            batch_shape = lax.broadcast_shapes(
                jnp.shape(alpha),
                jnp.shape(a_tripp),
                jnp.shape(b_tripp),
                jnp.shape(M_SN),
                jnp.shape(sigma_SN),
                jnp.shape(sigma_v),
                jnp.shape(m_true),
                jnp.shape(stretch_true),
                jnp.shape(c_true),
                jnp.shape(vbulk_x),
                jnp.shape(vbulk_y),
                jnp.shape(vbulk_z),
            )
            super().__init__(batch_shape=batch_shape)

        def sample(self, key, sample_shape=()):
            raise NotImplementedError

        def log_prob(self, value):
            # `value` is not used: the observed data are taken from the enclosing scope
            vbulk = jnp.array([self.vbulk_x, self.vbulk_y, self.vbulk_z])
            loglike = likelihood(self.alpha, self.a_tripp, self.b_tripp, self.M_SN, self.sigma_SN, self.sigma_v,
                                 self.m_true, self.stretch_true, self.c_true, vbulk,
                                 dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
                                 czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)
            return loglike

    rng_key = random.PRNGKey(6)
    rng_key, rng_key_ = random.split(rng_key)

    # Work on a copy so the caller's `initial` dictionary is left untouched
    values = dict(initial)
    # Start the latent true values at the observed ones, with no correlation or bulk flow
    values['true_vars'] = jnp.array([m_obs, stretch_obs, c_obs]).T
    values['L_corr'] = jnp.identity(3)
    values['vbulk_x'] = 0.
    values['vbulk_y'] = 0.
    values['vbulk_z'] = 0.

    myprint('Preparing MCMC kernel')
    kernel = numpyro.infer.NUTS(sn_model,
        init_strategy=numpyro.infer.initialization.init_to_value(values=values)
    )
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    myprint('Running MCMC')
    mcmc.run(rng_key_)
    mcmc.print_summary()

    return mcmc
def process_mcmc_run(mcmc, param_labels, truths, true_vars):
    """
    Make summary plots from the MCMC and save these to file

    Writes 'sn_trace.png' (trace plot), 'sn_corner.png' (corner plot) and, for each of
    stretch, c and m whose '<name>_true' samples are available,
    'sn_true_predicted_<name>.png'. Every figure is closed once it has been saved.

    Args:
        - mcmc (numpyro.infer.MCMC): MCMC object which has been run
        - param_labels (list[str]): Names of the parameters to plot. The labels
          'hyper_corr_mx', 'hyper_corr_mc' and 'hyper_corr_xc' select entries of the
          correlation matrix built from the 'L_corr' samples.
        - truths (list[float]): True values of the parameters to plot. If unknown, then entry is None
        - true_vars (dict): True values of the observables to compare against inferred ones

    Raises:
        - NotImplementedError: If a label starts with 'hyper_corr' but is not one of the three above
    """

    # Position of each correlation coefficient in the (m, stretch, c) correlation matrix
    corr_entries = {
        'hyper_corr_mx': (0, 1),
        'hyper_corr_mc': (0, 2),
        'hyper_corr_xc': (1, 2),
    }

    # Convert samples into a single array
    samples = mcmc.get_samples()
    samps = jnp.empty((len(samples[param_labels[0]]), len(param_labels)))
    corr_matrix = None
    for i, p in enumerate(param_labels):
        if p.startswith('hyper_corr'):
            if p not in corr_entries:
                raise NotImplementedError
            if corr_matrix is None:
                # Built once and reused for every correlation label
                L_corr = samples['L_corr']
                corr_matrix = jnp.matmul(L_corr, jnp.transpose(L_corr, (0, 2, 1)))
            row, col = corr_entries[p]
            samps = samps.at[:,i].set(corr_matrix[:,row,col])
        else:
            samps = samps.at[:,i].set(samples[p])

    # Trace plot of non-redshift quantities
    fig1, axs1 = plt.subplots(samps.shape[1], 1, figsize=(6,3*samps.shape[1]), sharex=True)
    axs1 = np.atleast_1d(axs1)
    for i in range(samps.shape[1]):
        axs1[i].plot(samps[:,i])
        axs1[i].set_ylabel(param_labels[i])
        if truths[i] is not None:
            axs1[i].axhline(truths[i], color='k')
    axs1[-1].set_xlabel('Step Number')
    fig1.tight_layout()
    fig1.savefig('sn_trace.png')
    plt.close(fig1)

    # Corner plot
    fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(25,25))
    corner.corner(
        np.array(samps),
        labels=param_labels,
        fig=fig2,
        truths=truths
    )
    fig2.savefig('sn_corner.png')
    plt.close(fig2)

    # True vs predicted
    for var in ['stretch', 'c', 'm']:
        vname = var + '_true'
        if vname in samples:
            xtrue = true_vars[var]
            # Median with 16th/84th percentile error bars
            xpred_median = np.median(samples[vname], axis=0)
            xpred_plus = np.percentile(samples[vname], 84, axis=0) - xpred_median
            xpred_minus = xpred_median - np.percentile(samples[vname], 16, axis=0)

            fig3, axs3 = plt.subplots(2, 1, figsize=(10,8), sharex=True)
            plot_kwargs = {'fmt':'.', 'markersize':3, 'zorder':10,
                     'capsize':1, 'elinewidth':1, 'alpha':1}
            axs3[0].errorbar(xtrue, xpred_median, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].errorbar(xtrue, xpred_median - xtrue, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].set_xlabel('True')
            axs3[0].set_ylabel('Predicted')
            axs3[1].set_ylabel('Predicted - True')
            # Draw the one-to-one line without letting it change the axis limits
            xlim = axs3[0].get_xlim()
            ylim = axs3[0].get_ylim()
            axs3[0].plot(xlim, xlim, color='k', zorder=0)
            axs3[0].set_xlim(xlim)
            axs3[0].set_ylim(ylim)
            axs3[1].axhline(0, color='k', zorder=0)
            fig3.suptitle(var)
            fig3.align_labels()
            fig3.tight_layout()
            fig3.savefig(f'sn_true_predicted_{var}.png')
            plt.close(fig3)

    return
def main():
    """
    Run the full mock test: estimate noise levels from the data file, generate a mock SN
    catalogue, evaluate and scan the likelihood at the true parameters, run the MCMC and
    produce the summary plots.
    """

    myprint('Beginning')

    (sigma_m, sigma_stretch, sigma_c,
     hyper_stretch_mu, hyper_stretch_sigma,
     hyper_c_mu, hyper_c_sigma) = estimate_data_parameters()

    # Other parameters to use
    L = 500.0
    N = 64
    xmin = -L/2
    R_lim = L / 2
    Rmax = 100
    Nt = 100
    alpha = 1.4
    sigma_v = 150
    interp_order = 1
    bias_epsilon = 1.e-7
    Nint_points = 201
    Nsig = 10
    frac_sigma_r = 0.07  # WANT A BETTER WAY OF DOING THIS - ESTIMATE THROUGH SIGMAS FROM Tripp formula

    # These values are from Table 6 of Boruah et al. 2020
    a_tripp = 0.140
    b_tripp = 2.78
    M_SN = - 18.558
    sigma_SN = 0.082

    num_warmup = 1000
    num_samples = 2000

    # Make mock
    np.random.seed(123)
    cpar, dens, vel = get_fields(L, N, xmin)
    mock = create_mock(
        Nt, L, xmin, cpar, dens, vel, Rmax, alpha,
        a_tripp, b_tripp, M_SN, sigma_SN, sigma_m, sigma_stretch, sigma_c,
        hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma,
        sigma_v, interp_order=interp_order, bias_epsilon=bias_epsilon)
    (RA, Dec, czCMB, m_true, stretch_true, c_true,
     m_obs, stretch_obs, c_obs, xtrue, vbulk) = mock
    MB_pos = generateMBData(RA, Dec, czCMB, L, N, R_lim, Nsig, Nint_points, sigma_v, frac_sigma_r)

    # Centre and width of the observed magnitude distribution
    m_obs_centre = np.median(m_obs)
    m_obs_width = (np.percentile(m_obs, 84) - np.percentile(m_obs, 16)) / 2

    initial = {
        'a_tripp': a_tripp,
        'b_tripp': b_tripp,
        'M_SN': M_SN,
        'sigma_SN': sigma_SN,
        'sigma_v': sigma_v,
        'hyper_mean_stretch': hyper_stretch_mu,
        'hyper_sigma_stretch': hyper_stretch_sigma,
        'hyper_mean_c': hyper_c_mu,
        'hyper_sigma_c': hyper_c_sigma,
        'hyper_mean_m': m_obs_centre,
        'hyper_sigma_m': m_obs_width,
    }
    prior = {
        'alpha': [0.5, 6],
        'a_tripp': [0.01, 0.2],
        'b_tripp': [2.5, 4.5],
        'M_SN': [-19.5, -17.5],
        'hyper_mean_stretch': [hyper_stretch_mu - hyper_stretch_sigma, hyper_stretch_mu + hyper_stretch_sigma],
        'hyper_mean_c':[hyper_c_mu - hyper_c_sigma, hyper_c_mu + hyper_c_sigma],
        'hyper_mean_m':[m_obs_centre - m_obs_width, m_obs_centre + m_obs_width],
        'sigma_v': [10, 3000],
    }

    # Argument groups shared by the likelihood calls below
    true_params = (alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
                   m_true, stretch_true, c_true, vbulk)
    field_args = (dens, vel, cpar.omega_m, cpar.h, L, xmin, interp_order, bias_epsilon)
    data_args = (czCMB, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos)

    # Test likelihood
    loglike = likelihood(*true_params, *field_args, *data_args)
    myprint(f'loglike {loglike}')

    # Scan over parameters to make plots verifying behaviour
    test_likelihood_scan(prior, *true_params, *field_args, *data_args)

    # Run a MCMC
    mcmc = run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, cpar, L, xmin,
                    interp_order, bias_epsilon, *data_args)

    param_labels = ['alpha', 'a_tripp', 'b_tripp', 'M_SN', 'sigma_SN', 'sigma_v',
                    'hyper_mean_m', 'hyper_sigma_m',
                    'hyper_mean_stretch', 'hyper_sigma_stretch', 'hyper_mean_c', 'hyper_sigma_c',
                    'hyper_corr_mx', 'hyper_corr_mc', 'hyper_corr_xc',
                    'vbulk_x', 'vbulk_y', 'vbulk_z']
    # None marks parameters whose true value is not known for the mock
    truths = [alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v,
              None, None,
              hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma,
              None, None, None,
              vbulk[0], vbulk[1], vbulk[2]]
    true_vars = {'m':m_true, 'stretch':stretch_true, 'c': c_true}
    process_mcmc_run(mcmc, param_labels, truths, true_vars)

    return


if __name__ == "__main__":
    main()