Successfully call loglike from BORG

This commit is contained in:
Deaglan Bartlett 2025-02-11 12:00:59 +01:00
parent 8d74403002
commit b4bf734782
3 changed files with 539 additions and 22 deletions

View file

@ -376,6 +376,11 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
c_true = [None] * nsamp
c_obs = [None] * nsamp
xtrue = [None] * nsamp
all_sigma_m = [None] * nsamp
all_sigma_eta = [None] * nsamp
all_sigma_stretch = [None] * nsamp
all_sigma_c = [None] * nsamp
all_mthresh = [None] * nsamp
for i in range(nsamp):
@ -397,12 +402,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
hyper_eta_mu = float(config[f'sample_{i}']['hyper_eta_mu'])
hyper_eta_sigma = float(config[f'sample_{i}']['hyper_eta_sigma'])
mthresh = float(config[f'sample_{i}']['mthresh'])
RA[i], Dec[i], czCMB[i], m_true[i], eta_true[i], m_obs[i], eta_obs[i], xtrue[i] = tfr_create_mock(
Nt, L, xmin, cosmo, dens, vel + vbulk,
Rmax, alpha, mthresh,
a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta,
hyper_eta_mu, hyper_eta_sigma, sigma_v,
interp_order=interp_order, bias_epsilon=bias_epsilon)
all_sigma_m[i] = np.full(Nt, sigma_m)
all_sigma_eta[i] = np.full(Nt, sigma_eta)
all_mthresh[i] = mthresh
elif tracer_type[i] == 'sn':
a_tripp = float(config[f'sample_{i}']['a_tripp'])
b_tripp = float(config[f'sample_{i}']['b_tripp'])
@ -415,12 +425,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
hyper_stretch_sigma = float(config[f'sample_{i}']['hyper_stretch_sigma'])
hyper_c_mu = float(config[f'sample_{i}']['hyper_c_mu'])
hyper_c_sigma = float(config[f'sample_{i}']['hyper_c_sigma'])
RA[i], Dec[i], czCMB[i], m_true[i], stretch_true[i], c_true[i], m_obs[i], stretch_obs[i], c_obs[i], xtrue[i] = sn_create_mock(
Nt, L, xmin, cosmo, dens, vel + vbulk, Rmax, alpha,
a_tripp, b_tripp, M_SN, sigma_SN, sigma_m, sigma_stretch, sigma_c,
hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma,
sigma_v, interp_order=1, bias_epsilon=1e-7)
all_sigma_m[i] = np.full(Nt, sigma_m)
all_sigma_stretch[i] = np.full(Nt, sigma_stretch)
all_sigma_c[i] = np.full(Nt, sigma_c)
else:
raise NotImplementedError
return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue
return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue, all_sigma_m, all_sigma_eta, all_sigma_stretch, all_sigma_c, all_mthresh

View file

@ -1,18 +1,21 @@
import aquila_borg as borg
import jax
import numpy as np
import jaxlib
import jax.numpy as jnp
import numpy as np
import configparser
import h5py
import warnings
import ast
import re
from functools import partial
import numbers
import borg_velocity.utils as utils
from borg_velocity.utils import myprint, generateMBData, compute_As, get_sigma_bulk
from borg_velocity.borg_mock import borg_mock
import borg_velocity.forwards as forwards
import borg_velocity.tracer_likelihoods as tracer_likelihoods
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler, TransformedBlackJaxBiasSampler
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
@ -40,9 +43,12 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
# Initialise grid parameters
self.N = [int(config['system'][f'N{i}']) for i in range(3)]
self.L = [float(config['system'][f'L{i}']) for i in range(3)]
self.corner = [float(config['system'][f'corner{i}']) for i in range(3)]
# For log-likelihood values
self.bignum = float(config['mcmc']['bignum'])
self.interp_order = int(config['model']['interp_order'])
self.bias_epsilon = float(config['model']['bias_epsilon'])
myprint(f" Init {self.N}, {self.L}")
super().__init__(fwd, self.N, self.L)
@ -52,16 +58,59 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
self.fwd_param = param_model
self.fwd_vel = fwd_vel
# Number of samples of distance tracers
self.nsamp = int(config['run']['nsamp'])
self.tracer_type = [config[f'sample_{i}']['tracer_type'] for i in range(self.nsamp)]
self.Nt = [int(config[f'sample_{i}']['Nt']) for i in range(self.nsamp)]
# Initialise cosmological parameters
cpar = utils.get_cosmopar(self.ini_file)
self.fwd.setCosmoParams(cpar)
self.fwd_param.setCosmoParams(cpar)
self.updateCosmology(cpar)
myprint(f"Original cosmological parameters: {self.fwd.getCosmoParams()}")
# Initialise model parameters
self.sig_v = float(config['model']['sig_v'])
self.bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
self.model_params = {
'sig_v': float(config['model']['sig_v']),
'bulk_flow_x': bulk_flow[0],
'bulk_flow_y': bulk_flow[1],
'bulk_flow_z': bulk_flow[2],
}
self.vector_model_params = {}
for i in range(self.nsamp):
tracer_type = config[f'sample_{i}']['tracer_type']
Nt = int(config[f'sample_{i}']['Nt'])
self.model_params[f'alpha_{i}'] = float(config[f'sample_{i}']['alpha'])
self.vector_model_params[f'm_true_{i}'] = np.zeros(Nt)
self.model_params[f'hyper_mean_m_{i}'] = 0.
self.model_params[f'hyper_sigma_m_{i}'] = 1.
if tracer_type == 'tfr':
self.vector_model_params[f'eta_true_{i}'] = np.zeros(Nt)
self.model_params[f'a_TFR_{i}'] = float(config[f'sample_{i}']['a_tfr'])
self.model_params[f'b_TFR_{i}'] = float(config[f'sample_{i}']['b_tfr'])
self.model_params[f'sigma_TFR_{i}'] = float(config[f'sample_{i}']['sigma_tfr'])
self.model_params[f'hyper_mean_eta_{i}'] = 0.
self.model_params[f'hyper_sigma_eta_{i}'] = 1.
self.model_params[f'hyper_corr_meta_{i}'] = 0.
elif tracer_type == 'sn':
self.vector_model_params[f'stretch_true_{i}'] = np.zeros(Nt)
self.vector_model_params[f'c_true_{i}'] = np.zeros(Nt)
self.model_params[f'a_tripp_{i}'] = float(config[f'sample_{i}']['a_tripp'])
self.model_params[f'b_tripp_{i}'] = float(config[f'sample_{i}']['b_tripp'])
self.model_params[f'M_SN_{i}'] = float(config[f'sample_{i}']['m_sn'])
self.model_params[f'sigma_SN_{i}'] = float(config[f'sample_{i}']['sigma_sn'])
self.model_params[f'hyper_mean_stretch_{i}'] = 0.
self.model_params[f'hyper_sigma_stretch_{i}'] = 1.
self.model_params[f'hyper_mean_c_{i}'] = 0.
self.model_params[f'hyper_sigma_c_{i}'] = 1.
self.model_params[f'hyper_corr_mx_{i}'] = 0.
self.model_params[f'hyper_corr_mc_{i}'] = 0.
self.model_params[f'hyper_corr_xc_{i}'] = 0.
else:
raise ValueError
self.fwd_param.setModelParams(self.model_params)
# Initialise derivative
self.grad_like = jax.grad(self.dens2like, argnums=(0,1))
@ -78,9 +127,6 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
# Seed if creating mocks
self.mock_seed = int(config['mock']['seed'])
# Number of samples to distance tracers
self.nsamp = int(config['run']['nsamp'])
# Initialise integration variables
if config['model']['R_lim'] == 'none':
self.R_lim = self.L[0]/2
@ -103,6 +149,17 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
myprint("Init likelihood")
state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)
for i in range(self.nsamp):
if self.tracer_type[i] == 'tfr':
state.newArray1d(f"m_true_{i}", self.Nt[i], True)
state.newArray1d(f"eta_true_{i}", self.Nt[i], True)
elif self.tracer_type[i] == 'sn':
state.newArray1d(f"m_true_{i}", self.Nt[i], True)
state.newArray1d(f"stretch_true_{i}", self.Nt[i], True)
state.newArray1d(f"c_true_{i}", self.Nt[i], True)
else:
raise ValueError(f"Unknown tracer type: {self.tracer_type[i]}")
if self.run_type == 'data':
self.loadObservedData()
elif borg.EMBEDDED:
@ -162,21 +219,67 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
- make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
"""
self.data = {}
self.truths = {}
if self.run_type == 'data':
raise NotImplementedError
elif self.run_type == 'velmass':
raise NotImplementedError
elif self.run_type == 'mock':
self.tracer_type, self.RA, self.Dec, self.cz_obs, self.m_true, self.m_obs, self.eta_true, self.eta_obs, self.stretch_true, self.stretch_obs, self.c_true, self.c_obs, self.xtrue = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed)
res = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed)
self.data['tracer_type'] = res[0]
self.data['RA'] = res[1]
self.data['Dec'] = res[2]
self.data['cz_obs'] = res[3]
self.truths['m_true'] = res[4]
self.data['m_obs'] = res[5]
self.truths['eta_true'] = res[6]
self.data['eta_obs'] = res[7]
self.truths['stretch_true'] = res[8]
self.data['stretch_obs'] = res[9]
self.truths['c_true'] = res[10]
self.data['c_obs'] = res[11]
self.truths['xtrue'] = res[12]
self.data['sigma_m'] = res[13]
self.data['sigma_eta'] = res[14]
self.data['sigma_stretch'] = res[15]
self.data['sigma_c'] = res[16]
self.data['mthresh'] = res[17]
else:
raise NotImplementedError
# Get integration points
for i in range(self.nsamp):
self.MB_pos[i] = utils.generateMBData(
self.RA[i], self.Dec[i], self.cz_obs[i], self.L[0], self.N[0],
self.R_lim, self.Nsig, self.Nint_points, self.sig_v, self.frac_sigma_r[i])
self.data['RA'][i], self.data['Dec'][i], self.data['cz_obs'][i], self.L[0], self.N[0],
self.R_lim, self.Nsig, self.Nint_points, self.model_params['sig_v'], self.frac_sigma_r[i])
# Initialise model parameters
for i in range(self.nsamp):
state[f'm_true_{i}'][:] = self.data['m_obs'][i]
if self.data['tracer_type'][i] == 'tfr':
state[f'eta_true_{i}'][:] = self.data['eta_obs'][i]
self.model_params[f'hyper_mean_m_{i}'] = np.amax(self.data['m_obs'][i])
self.model_params[f'hyper_sigma_m_{i}'] = np.amax(self.data['m_obs'][i]) - np.percentile(self.data['m_obs'][i], 16)
self.model_params[f'hyper_mean_eta_{i}'] = np.median(self.data['eta_obs'][i])
self.model_params[f'hyper_sigma_eta_{i}'] = (np.percentile(self.data['eta_obs'][i], 84) - np.percentile(self.data['eta_obs'][i], 16)) / 2
elif self.data['tracer_type'][i] == 'sn':
state[f'stretch_true_{i}'][:] = self.data['stretch_obs'][i]
state[f'c_true_{i}'][:] = self.data['c_obs'][i]
self.model_params[f'hyper_mean_m_{i}'] = np.median(self.data['m_obs'][i])
self.model_params[f'hyper_sigma_m_{i}'] = (np.percentile(self.data['m_obs'][i], 84) - np.percentile(self.data['m_obs'][i], 16)) / 2
self.model_params[f'hyper_mean_stretch_{i}'] = np.median(self.data['stretch_obs'][i])
self.model_params[f'hyper_sigma_stretch_{i}'] = (np.percentile(self.data['stretch_obs'][i], 84) - np.percentile(self.data['stretch_obs'][i], 16)) / 2
self.model_params[f'hyper_mean_c_{i}'] = np.median(self.data['c_obs'][i])
self.model_params[f'hyper_sigma_stretch_{i}'] = (np.percentile(self.data['c_obs'][i], 84) - np.percentile(self.data['c_obs'][i], 16)) / 2
else:
raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}")
self.fwd_param.setModelParams(self.model_params)
for k in self.vector_model_params.keys():
self.vector_model_params[k] = state[k]
# Save mock to file
# with h5py.File(f'tracer_data_{self.run_type}.h5', 'w') as h5file:
# for i in range(self.nsamp):
@ -213,23 +316,43 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
lkl = 0
sig_v = self.model_params['sig_v']
sigma_v = self.model_params['sig_v']
# Compute velocity field
self.bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
self.model_params['bulk_flow_y'],
self.model_params['bulk_flow_z']])
v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1))
vel = output_velocity + bulk_flow.reshape((3, 1, 1, 1))
omega_m = self.fwd.getCosmoParams().omega_m
h = self.fwd.getCosmoParams().h
raise NotImplementedError
loglike = 0.
for i in range(self.nsamp):
if self.data['tracer_type'][i] == 'tfr':
loglike += tracer_likelihoods.tfr_likelihood(
output_density, vel, self.model_params[f'alpha_{i}'],
self.model_params[f'a_TFR_{i}'], self.model_params[f'b_TFR_{i}'], self.model_params[f'sigma_TFR_{i}'], sigma_v,
self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'eta_true_{i}'],
omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon,
self.data['cz_obs'][i], self.data['m_obs'][i], self.data['eta_obs'][i],
self.data['sigma_m'][i], self.data['sigma_eta'][i], self.MB_pos[i], self.data['mthresh'][i])
elif self.data['tracer_type'][i] == 'sn':
loglike += tracer_likelihoods.sn_likelihood(
output_density, vel, self.model_params[f'alpha_{i}'],
self.model_params[f'a_tripp_{i}'], self.model_params[f'b_tripp_{i}'],
self.model_params[f'M_SN_{i}'], self.model_params[f'sigma_SN_{i}'], sigma_v,
self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'stretch_true_{i}'], self.vector_model_params[f'c_true_{i}'],
omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon,
self.data['cz_obs'][i], self.data['m_obs'][i], self.data['stretch_obs'][i], self.data['c_obs'][i],
self.data['sigma_m'][i], self.data['sigma_stretch'][i], self.data['sigma_c'][i], self.MB_pos[i])
else:
raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}")
# Add in bulk flow prior
lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + self.bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2)
lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2)
# lkl = jnp.clip(lkl, -self.bignum, self.bignum)
lkl = lax.cond(
lkl = jax.lax.cond(
jnp.isfinite(lkl),
lambda _: lkl, # If True (finite), return lkl
lambda _: self.bignum, # If False (not finite), return self.bignum
@ -697,12 +820,14 @@ def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.L
"""
To Do:
- Add in the true colours etc. to the model parameter sampling
- Call to the correct log-likelihoods
- Sample the variables from priors
- Check global variables are used
- Create minimal ini file
- Docstrings
- Don't want all these samplers - only gradient ones
- Save mock to file
- How to sample the true values now they are not in model_params?
- How too update the vector of true values for likelihood call?
- Commit true values to file
- Check changes to sigma_m etc. likelihoods in original tests
"""

View file

@ -0,0 +1,377 @@
import jax
import jax.numpy as jnp
import borg_velocity.projection as projection
import borg_velocity.utils as utils
def tfr_likelihood_vel(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, MB_pos, mthresh):
"""
Evaluate the terms in the likelihood from the velocity and malmquist bias for TFR tracers
Args:
- alpha (float): Exponent for bias model
- a_TFR (float): TFR relation intercept
- b_TFR (float): TFR relation slope
- sigma_TFR (float): Intrinsic scatter in the TFR
- sigma_v (float): Uncertainty on the velocity field (km/s)
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
- dens (np.ndarray): Over-density field (shape = (N, N, N))
- vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
- omega_m (float): Matter density parameter Om
- h (float): Hubble constant H0 = 100 h km/s/Mpc
- L (float): Comoving box size (Mpc/h)
- xmin (float): Coordinate of corner of the box (Mpc/h)
- interp_order (int): Order of interpolation from grid points to the line of sight
- bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
- cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
- mthresh (float): Threshold absolute magnitude in selection
Returns:
- loglike (float): The log-likelihood of the data
"""
# Comoving radii of integration points (Mpc/h)
r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))
# p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR)
# Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm)
number_density = projection.interp_field(
dens,
MB_pos,
L,
jnp.array([xmin, xmin, xmin]),
interp_order,
use_jitted=True,
)
number_density = jax.nn.relu(1. + number_density)
number_density = jnp.power(number_density + bias_epsilon, alpha)
zcosmo = utils.z_cos(r, omega_m)
mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
muTFR = m_true - (a_TFR + b_TFR * eta_true)
d2 = ((mutrue - muTFR[:,None]) / sigma_TFR) ** 2
best = jnp.amin(jnp.abs(d2), axis=1)
d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)
# Peculiar velocity term
tracer_vel = projection.interp_field(
vel,
MB_pos,
L,
jnp.array([xmin, xmin, xmin]),
interp_order,
use_jitted=True,
)
tracer_vr = projection.project_radial(
tracer_vel,
MB_pos,
jnp.zeros(3,)
)
cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2
scale = jnp.nanmin(d2, axis=1)
d2 = d2 - jnp.expand_dims(scale, axis=1)
# Integrate to get likelihood
p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2)
loglike = lkl_ind.sum()
return loglike
def tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh):
"""
Evaluate the terms in the likelihood from apparent magnitude for TFR tracers
Args:
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- mthresh (float): Threshold absolute magnitude in selection
Returns:
- loglike (float): The log-likelihood of the data
"""
Nt = m_obs.shape[0]
norm = (
jnp.log(2)
- jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m)))
- 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
)
loglike = - (
jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
- jnp.sum(norm)
)
return loglike
def tfr_likelihood_eta(eta_true, eta_obs, sigma_eta):
"""
Evaluate the terms in the likelihood from linewidth for TFR tracers
Args:
- eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike = - (
jnp.sum(0.5 * (eta_obs - eta_true) ** 2 / sigma_eta ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2))
)
return loglike
def tfr_likelihood(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh):
"""
Evaluate the likelihood for TFR sample
Args:
- dens (np.ndarray): Over-density field (shape = (N, N, N))
- vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
- alpha (float): Exponent for bias model
- a_TFR (float): TFR relation intercept
- b_TFR (float): TFR relation slope
- sigma_TFR (float): Intrinsic scatter in the TFR
- sigma_v (float): Uncertainty on the velocity field (km/s)
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
- omega_m (float): Matter density parameter Om
- h (float): Hubble constant H0 = 100 h km/s/Mpc
- L (float): Comoving box size (Mpc/h)
- xmin (float): Coordinate of corner of the box (Mpc/h)
- interp_order (int): Order of interpolation from grid points to the line of sight
- bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
- cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
- mthresh (float): Threshold absolute magnitude in selection
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike_vel = tfr_likelihood_vel(dens, vel,
alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, MB_pos, mthresh)
loglike_m = tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh)
loglike_eta = tfr_likelihood_eta(eta_true, eta_obs, sigma_eta)
loglike = (loglike_vel + loglike_m + loglike_eta)
return loglike
def sn_likelihood_vel(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, MB_pos):
"""
Evaluate the terms in the likelihood from the velocity and malmquist bias for supernovae
Args:
- dens (np.ndarray): Over-density field (shape = (N, N, N))
- vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
- alpha (float): Exponent for bias model
- a_tripp (float): Coefficient of stretch in the Tripp relation
- b_tripp (float): Coefficient of colour in the Tripp relation
- M_SN (float): Absolute magnitude of supernovae
- sigma_SN (float): Intrinsic scatter in the Tripp relation
- sigma_v (float): Uncertainty on the velocity field (km/s)
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
- c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
- omega_m (float): Matter density parameter Om
- h (float): Hubble constant H0 = 100 h km/s/Mpc
- L (float): Comoving box size (Mpc/h)
- xmin (float): Coordinate of corner of the box (Mpc/h)
- interp_order (int): Order of interpolation from grid points to the line of sight
- bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
- cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
Returns:
- loglike (float): The log-likelihood of the data
"""
# Comoving radii of integration points (Mpc/h)
r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))
# p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR)
# Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm)
number_density = projection.interp_field(
dens,
MB_pos,
L,
jnp.array([xmin, xmin, xmin]),
interp_order,
use_jitted=True,
)
number_density = jax.nn.relu(1. + number_density)
number_density = jnp.power(number_density + bias_epsilon, alpha)
zcosmo = utils.z_cos(r, omega_m)
mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
mutripp = m_true + a_tripp * stretch_true - b_tripp * c_true - M_SN
d2 = ((mutrue - mutripp[:,None]) / sigma_SN) ** 2
best = jnp.amin(jnp.abs(d2), axis=1)
d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)
# Peculiar velocity term
tracer_vel = projection.interp_field(
vel,
MB_pos,
L,
jnp.array([xmin, xmin, xmin]),
interp_order,
use_jitted=True,
)
tracer_vr = projection.project_radial(
tracer_vel,
MB_pos,
jnp.zeros(3,)
)
cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2
scale = jnp.nanmin(d2, axis=1)
d2 = d2 - jnp.expand_dims(scale, axis=1)
# Integrate to get likelihood
p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2)
loglike = lkl_ind.sum()
return loglike
def sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch):
"""
Evaluate the terms in the likelihood from stretch for supernovae
Args:
- stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,))
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike = - (
jnp.sum(0.5 * (stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2))
)
return loglike
def sn_likelihood_c(c_true, c_obs, sigma_c):
"""
Evaluate the terms in the likelihood from colour for supernovae
Args:
- c_true (np.ndarray): True colours of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,))
- sigma_c (float or np.ndarray): Uncertainty on the colours measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike = - (
jnp.sum(0.5 * (c_obs - c_true) ** 2 / sigma_c ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2))
)
return loglike
def sn_likelihood_m(m_true, m_obs, sigma_m):
"""
Evaluate the terms in the likelihood from apparent magntiude for supernovae
Args:
- m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,))
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike = - (
jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
)
return loglike
def sn_likelihood(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos):
"""
Evaluate the likelihood for SN sample
Args:
- dens (np.ndarray): Over-density field (shape = (N, N, N))
- vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
- alpha (float): Exponent for bias model
- a_tripp (float): Coefficient of stretch in the Tripp relation
- b_tripp (float): Coefficient of colour in the Tripp relation
- M_SN (float): Absolute magnitude of supernovae
- sigma_SN (float): Intrinsic scatter in the Tripp relation
- sigma_v (float): Uncertainty on the velocity field (km/s)
- m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
- stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
- c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
- omega_m (float): Matter density parameter Om
- h (float): Hubble constant H0 = 100 h km/s/Mpc
- L (float): Comoving box size (Mpc/h)
- xmin (float): Coordinate of corner of the box (Mpc/h)
- interp_order (int): Order of interpolation from grid points to the line of sight
- bias_epsilon (float): Small number to add to 1 + delta to prevent 0^#
- cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
- m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
- stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
- c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
- sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
- sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
- sigma_c (float or np.ndarray): Uncertainty on the colour measurements
- MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
The shape is (3, Nt, Nsig)
Returns:
- loglike (float): The log-likelihood of the data
"""
loglike_vel = sn_likelihood_vel(dens, vel,
alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
omega_m, h, L, xmin, interp_order, bias_epsilon,
cz_obs, MB_pos)
loglike_stretch = sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch)
loglike_c = sn_likelihood_c(c_true, c_obs, sigma_c)
loglike_m = sn_likelihood_m(m_true, m_obs, sigma_m)
loglike = loglike_vel + loglike_stretch + loglike_c + loglike_m
return loglike