Successfully call loglike from BORG
This commit is contained in:
parent 8d74403002
commit b4bf734782
3 changed files with 539 additions and 22 deletions

@@ -376,6 +376,11 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
    c_true = [None] * nsamp
    c_obs = [None] * nsamp
    xtrue = [None] * nsamp
    all_sigma_m = [None] * nsamp
    all_sigma_eta = [None] * nsamp
    all_sigma_stretch = [None] * nsamp
    all_sigma_c = [None] * nsamp
    all_mthresh = [None] * nsamp

    for i in range(nsamp):

@@ -397,12 +402,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
            hyper_eta_mu = float(config[f'sample_{i}']['hyper_eta_mu'])
            hyper_eta_sigma = float(config[f'sample_{i}']['hyper_eta_sigma'])
            mthresh = float(config[f'sample_{i}']['mthresh'])

            RA[i], Dec[i], czCMB[i], m_true[i], eta_true[i], m_obs[i], eta_obs[i], xtrue[i] = tfr_create_mock(
                Nt, L, xmin, cosmo, dens, vel + vbulk,
                Rmax, alpha, mthresh,
                a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta,
                hyper_eta_mu, hyper_eta_sigma, sigma_v,
                interp_order=interp_order, bias_epsilon=bias_epsilon)
            all_sigma_m[i] = np.full(Nt, sigma_m)
            all_sigma_eta[i] = np.full(Nt, sigma_eta)
            all_mthresh[i] = mthresh

        elif tracer_type[i] == 'sn':
            a_tripp = float(config[f'sample_{i}']['a_tripp'])
            b_tripp = float(config[f'sample_{i}']['b_tripp'])

@@ -415,12 +425,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None):
            hyper_stretch_sigma = float(config[f'sample_{i}']['hyper_stretch_sigma'])
            hyper_c_mu = float(config[f'sample_{i}']['hyper_c_mu'])
            hyper_c_sigma = float(config[f'sample_{i}']['hyper_c_sigma'])

            RA[i], Dec[i], czCMB[i], m_true[i], stretch_true[i], c_true[i], m_obs[i], stretch_obs[i], c_obs[i], xtrue[i] = sn_create_mock(
                Nt, L, xmin, cosmo, dens, vel + vbulk, Rmax, alpha,
                a_tripp, b_tripp, M_SN, sigma_SN, sigma_m, sigma_stretch, sigma_c,
                hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma,
                sigma_v, interp_order=1, bias_epsilon=1e-7)
            all_sigma_m[i] = np.full(Nt, sigma_m)
            all_sigma_stretch[i] = np.full(Nt, sigma_stretch)
            all_sigma_c[i] = np.full(Nt, sigma_c)

        else:
            raise NotImplementedError

    return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue
    return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue, all_sigma_m, all_sigma_eta, all_sigma_stretch, all_sigma_c, all_mthresh

@@ -1,18 +1,21 @@
import aquila_borg as borg
import jax
import numpy as np
import jaxlib
import jax.numpy as jnp
import numpy as np
import configparser
import h5py
import warnings
import ast
import re
from functools import partial
import numbers

import borg_velocity.utils as utils
from borg_velocity.utils import myprint, generateMBData, compute_As, get_sigma_bulk
from borg_velocity.borg_mock import borg_mock
import borg_velocity.forwards as forwards
import borg_velocity.tracer_likelihoods as tracer_likelihoods
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler, TransformedBlackJaxBiasSampler

class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):

@@ -40,9 +43,12 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
        # Initialise grid parameters
        self.N = [int(config['system'][f'N{i}']) for i in range(3)]
        self.L = [float(config['system'][f'L{i}']) for i in range(3)]
        self.corner = [float(config['system'][f'corner{i}']) for i in range(3)]

        # For log-likelihood values
        self.bignum = float(config['mcmc']['bignum'])
        self.interp_order = int(config['model']['interp_order'])
        self.bias_epsilon = float(config['model']['bias_epsilon'])

        myprint(f" Init {self.N}, {self.L}")
        super().__init__(fwd, self.N, self.L)

@@ -52,16 +58,59 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
        self.fwd_param = param_model
        self.fwd_vel = fwd_vel

        # Number of samples of distance tracers
        self.nsamp = int(config['run']['nsamp'])
        self.tracer_type = [config[f'sample_{i}']['tracer_type'] for i in range(self.nsamp)]
        self.Nt = [int(config[f'sample_{i}']['Nt']) for i in range(self.nsamp)]

        # Initialise cosmological parameters
        cpar = utils.get_cosmopar(self.ini_file)
        self.fwd.setCosmoParams(cpar)
        self.fwd_param.setCosmoParams(cpar)
        self.updateCosmology(cpar)
        myprint(f"Original cosmological parameters: {self.fwd.getCosmoParams()}")

        # Initialise model parameters
        self.sig_v = float(config['model']['sig_v'])
        self.bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
        bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow']))
        self.model_params = {
            'sig_v': float(config['model']['sig_v']),
            'bulk_flow_x': bulk_flow[0],
            'bulk_flow_y': bulk_flow[1],
            'bulk_flow_z': bulk_flow[2],
        }
        self.vector_model_params = {}
        for i in range(self.nsamp):
            tracer_type = config[f'sample_{i}']['tracer_type']
            Nt = int(config[f'sample_{i}']['Nt'])
            self.model_params[f'alpha_{i}'] = float(config[f'sample_{i}']['alpha'])
            self.vector_model_params[f'm_true_{i}'] = np.zeros(Nt)
            self.model_params[f'hyper_mean_m_{i}'] = 0.
            self.model_params[f'hyper_sigma_m_{i}'] = 1.
            if tracer_type == 'tfr':
                self.vector_model_params[f'eta_true_{i}'] = np.zeros(Nt)
                self.model_params[f'a_TFR_{i}'] = float(config[f'sample_{i}']['a_tfr'])
                self.model_params[f'b_TFR_{i}'] = float(config[f'sample_{i}']['b_tfr'])
                self.model_params[f'sigma_TFR_{i}'] = float(config[f'sample_{i}']['sigma_tfr'])
                self.model_params[f'hyper_mean_eta_{i}'] = 0.
                self.model_params[f'hyper_sigma_eta_{i}'] = 1.
                self.model_params[f'hyper_corr_meta_{i}'] = 0.
            elif tracer_type == 'sn':
                self.vector_model_params[f'stretch_true_{i}'] = np.zeros(Nt)
                self.vector_model_params[f'c_true_{i}'] = np.zeros(Nt)
                self.model_params[f'a_tripp_{i}'] = float(config[f'sample_{i}']['a_tripp'])
                self.model_params[f'b_tripp_{i}'] = float(config[f'sample_{i}']['b_tripp'])
                self.model_params[f'M_SN_{i}'] = float(config[f'sample_{i}']['m_sn'])
                self.model_params[f'sigma_SN_{i}'] = float(config[f'sample_{i}']['sigma_sn'])
                self.model_params[f'hyper_mean_stretch_{i}'] = 0.
                self.model_params[f'hyper_sigma_stretch_{i}'] = 1.
                self.model_params[f'hyper_mean_c_{i}'] = 0.
                self.model_params[f'hyper_sigma_c_{i}'] = 1.
                self.model_params[f'hyper_corr_mx_{i}'] = 0.
                self.model_params[f'hyper_corr_mc_{i}'] = 0.
                self.model_params[f'hyper_corr_xc_{i}'] = 0.
            else:
                raise ValueError
        self.fwd_param.setModelParams(self.model_params)

        # Initialise derivative
        self.grad_like = jax.grad(self.dens2like, argnums=(0,1))
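
        # Illustrative note (toy names only, not part of this module): with argnums=(0, 1),
        # jax.grad returns a tuple of gradients, one per selected positional argument, e.g.
        #   f = lambda d, v: jnp.sum(d ** 2) + jnp.sum(v ** 2)
        #   grad_d, grad_v = jax.grad(f, argnums=(0, 1))(jnp.ones(3), jnp.ones(3))
        # so dens2like must return a scalar for grad_like to be well defined.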

@@ -78,9 +127,6 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
        # Seed if creating mocks
        self.mock_seed = int(config['mock']['seed'])

        # Number of samples of distance tracers
        self.nsamp = int(config['run']['nsamp'])

        # Initialise integration variables
        if config['model']['R_lim'] == 'none':
            self.R_lim = self.L[0]/2

@@ -103,6 +149,17 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
        myprint("Init likelihood")
        state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)

        for i in range(self.nsamp):
            if self.tracer_type[i] == 'tfr':
                state.newArray1d(f"m_true_{i}", self.Nt[i], True)
                state.newArray1d(f"eta_true_{i}", self.Nt[i], True)
            elif self.tracer_type[i] == 'sn':
                state.newArray1d(f"m_true_{i}", self.Nt[i], True)
                state.newArray1d(f"stretch_true_{i}", self.Nt[i], True)
                state.newArray1d(f"c_true_{i}", self.Nt[i], True)
            else:
                raise ValueError(f"Unknown tracer type: {self.tracer_type[i]}")

        if self.run_type == 'data':
            self.loadObservedData()
        elif borg.EMBEDDED:

@@ -162,21 +219,67 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
        - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        - make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
        """

        self.data = {}
        self.truths = {}

        if self.run_type == 'data':
            raise NotImplementedError
        elif self.run_type == 'velmass':
            raise NotImplementedError
        elif self.run_type == 'mock':
            self.tracer_type, self.RA, self.Dec, self.cz_obs, self.m_true, self.m_obs, self.eta_true, self.eta_obs, self.stretch_true, self.stretch_obs, self.c_true, self.c_obs, self.xtrue = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed)
            res = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed)
            self.data['tracer_type'] = res[0]
            self.data['RA'] = res[1]
            self.data['Dec'] = res[2]
            self.data['cz_obs'] = res[3]
            self.truths['m_true'] = res[4]
            self.data['m_obs'] = res[5]
            self.truths['eta_true'] = res[6]
            self.data['eta_obs'] = res[7]
            self.truths['stretch_true'] = res[8]
            self.data['stretch_obs'] = res[9]
            self.truths['c_true'] = res[10]
            self.data['c_obs'] = res[11]
            self.truths['xtrue'] = res[12]
            self.data['sigma_m'] = res[13]
            self.data['sigma_eta'] = res[14]
            self.data['sigma_stretch'] = res[15]
            self.data['sigma_c'] = res[16]
            self.data['mthresh'] = res[17]
        else:
            raise NotImplementedError

        # Get integration points
        for i in range(self.nsamp):
            self.MB_pos[i] = utils.generateMBData(
                self.RA[i], self.Dec[i], self.cz_obs[i], self.L[0], self.N[0],
                self.R_lim, self.Nsig, self.Nint_points, self.sig_v, self.frac_sigma_r[i])
                self.data['RA'][i], self.data['Dec'][i], self.data['cz_obs'][i], self.L[0], self.N[0],
                self.R_lim, self.Nsig, self.Nint_points, self.model_params['sig_v'], self.frac_sigma_r[i])

        # Initialise model parameters
        for i in range(self.nsamp):
            state[f'm_true_{i}'][:] = self.data['m_obs'][i]
            if self.data['tracer_type'][i] == 'tfr':
                state[f'eta_true_{i}'][:] = self.data['eta_obs'][i]
                self.model_params[f'hyper_mean_m_{i}'] = np.amax(self.data['m_obs'][i])
                self.model_params[f'hyper_sigma_m_{i}'] = np.amax(self.data['m_obs'][i]) - np.percentile(self.data['m_obs'][i], 16)
                self.model_params[f'hyper_mean_eta_{i}'] = np.median(self.data['eta_obs'][i])
                self.model_params[f'hyper_sigma_eta_{i}'] = (np.percentile(self.data['eta_obs'][i], 84) - np.percentile(self.data['eta_obs'][i], 16)) / 2
            elif self.data['tracer_type'][i] == 'sn':
                state[f'stretch_true_{i}'][:] = self.data['stretch_obs'][i]
                state[f'c_true_{i}'][:] = self.data['c_obs'][i]
                self.model_params[f'hyper_mean_m_{i}'] = np.median(self.data['m_obs'][i])
                self.model_params[f'hyper_sigma_m_{i}'] = (np.percentile(self.data['m_obs'][i], 84) - np.percentile(self.data['m_obs'][i], 16)) / 2
                self.model_params[f'hyper_mean_stretch_{i}'] = np.median(self.data['stretch_obs'][i])
                self.model_params[f'hyper_sigma_stretch_{i}'] = (np.percentile(self.data['stretch_obs'][i], 84) - np.percentile(self.data['stretch_obs'][i], 16)) / 2
                self.model_params[f'hyper_mean_c_{i}'] = np.median(self.data['c_obs'][i])
                self.model_params[f'hyper_sigma_c_{i}'] = (np.percentile(self.data['c_obs'][i], 84) - np.percentile(self.data['c_obs'][i], 16)) / 2
            else:
                raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}")
        self.fwd_param.setModelParams(self.model_params)
        for k in self.vector_model_params.keys():
            self.vector_model_params[k] = state[k]

        # Save mock to file
        # with h5py.File(f'tracer_data_{self.run_type}.h5', 'w') as h5file:
        #     for i in range(self.nsamp):
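        # One possible completion of the commented-out save stub above (a sketch only;
        # the group/dataset names here are illustrative, not an established format):
        #     with h5py.File(f'tracer_data_{self.run_type}.h5', 'w') as h5file:
        #         for i in range(self.nsamp):
        #             grp = h5file.create_group(f'sample_{i}')
        #             grp.create_dataset('m_obs', data=np.asarray(self.data['m_obs'][i]))
        #             grp.create_dataset('cz_obs', data=np.asarray(self.data['cz_obs'][i]))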

@@ -213,23 +316,43 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):

        lkl = 0

        sig_v = self.model_params['sig_v']
        sigma_v = self.model_params['sig_v']

        # Compute velocity field
        self.bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
        bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
                               self.model_params['bulk_flow_y'],
                               self.model_params['bulk_flow_z']])
        v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1))
        vel = output_velocity + bulk_flow.reshape((3, 1, 1, 1))

        omega_m = self.fwd.getCosmoParams().omega_m
        h = self.fwd.getCosmoParams().h

        raise NotImplementedError
        loglike = 0.
        for i in range(self.nsamp):
            if self.data['tracer_type'][i] == 'tfr':
                loglike += tracer_likelihoods.tfr_likelihood(
                    output_density, vel, self.model_params[f'alpha_{i}'],
                    self.model_params[f'a_TFR_{i}'], self.model_params[f'b_TFR_{i}'], self.model_params[f'sigma_TFR_{i}'], sigma_v,
                    self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'eta_true_{i}'],
                    omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon,
                    self.data['cz_obs'][i], self.data['m_obs'][i], self.data['eta_obs'][i],
                    self.data['sigma_m'][i], self.data['sigma_eta'][i], self.MB_pos[i], self.data['mthresh'][i])
            elif self.data['tracer_type'][i] == 'sn':
                loglike += tracer_likelihoods.sn_likelihood(
                    output_density, vel, self.model_params[f'alpha_{i}'],
                    self.model_params[f'a_tripp_{i}'], self.model_params[f'b_tripp_{i}'],
                    self.model_params[f'M_SN_{i}'], self.model_params[f'sigma_SN_{i}'], sigma_v,
                    self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'stretch_true_{i}'], self.vector_model_params[f'c_true_{i}'],
                    omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon,
                    self.data['cz_obs'][i], self.data['m_obs'][i], self.data['stretch_obs'][i], self.data['c_obs'][i],
                    self.data['sigma_m'][i], self.data['sigma_stretch'][i], self.data['sigma_c'][i], self.MB_pos[i])
            else:
                raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}")

        # Add in bulk flow prior
        lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + self.bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2)
        lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2)

        # lkl = jnp.clip(lkl, -self.bignum, self.bignum)
        lkl = lax.cond(
        lkl = jax.lax.cond(
            jnp.isfinite(lkl),
            lambda _: lkl,  # If True (finite), return lkl
            lambda _: self.bignum,  # If False (not finite), return self.bignum
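            # Standalone sketch of this guard pattern (names and values illustrative): replace a
            # non-finite scalar with a large penalty while remaining traceable by JAX:
            #   safe_lkl = jax.lax.cond(jnp.isfinite(lkl), lambda _: lkl, lambda _: 1e20, None)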

@@ -697,12 +820,14 @@ def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.L
    """
    To Do:
        - Add in the true colours etc. to the model parameter sampling
        - Call to the correct log-likelihoods
        - Sample the variables from priors
        - Check global variables are used
        - Create minimal ini file
        - Docstrings
        - Don't want all these samplers - only gradient ones
        - Save mock to file
        - How to sample the true values now they are not in model_params?
        - How to update the vector of true values for the likelihood call?
        - Commit true values to file
        - Check changes to sigma_m etc. likelihoods in original tests
    """

377 borg_velocity/tracer_likelihoods.py Normal file

@@ -0,0 +1,377 @@
import jax
import jax.numpy as jnp

import borg_velocity.projection as projection
import borg_velocity.utils as utils


def tfr_likelihood_vel(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
                       omega_m, h, L, xmin, interp_order, bias_epsilon,
                       cz_obs, MB_pos, mthresh):
    """
    Evaluate the terms in the likelihood from the velocity and Malmquist bias for TFR tracers

    Args:
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - alpha (float): Exponent for bias model
        - a_TFR (float): TFR relation intercept
        - b_TFR (float): TFR relation slope
        - sigma_TFR (float): Intrinsic scatter in the TFR
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^alpha
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)
        - mthresh (float): Threshold apparent magnitude in selection

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    # Comoving radii of integration points (Mpc/h)
    r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))

    # p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR)
    # Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm)
    number_density = projection.interp_field(
        dens,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    number_density = jax.nn.relu(1. + number_density)
    number_density = jnp.power(number_density + bias_epsilon, alpha)
    zcosmo = utils.z_cos(r, omega_m)
    mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
    muTFR = m_true - (a_TFR + b_TFR * eta_true)
    d2 = ((mutrue - muTFR[:,None]) / sigma_TFR) ** 2
    best = jnp.amin(jnp.abs(d2), axis=1)
    d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
    p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
    p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)

    # Peculiar velocity term
    tracer_vel = projection.interp_field(
        vel,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    tracer_vr = projection.project_radial(
        tracer_vel,
        MB_pos,
        jnp.zeros(3,)
    )
    cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
    d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2
    scale = jnp.nanmin(d2, axis=1)
    d2 = d2 - jnp.expand_dims(scale, axis=1)

    # Integrate to get likelihood
    p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
    lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2)
    loglike = lkl_ind.sum()

    return loglike
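
# For reference, tfr_likelihood_vel above evaluates, by trapezoidal integration over the
# Nsig points per tracer stored in MB_pos, the radial marginalisation (notation informal)
#   p(cz_obs_i) = ∫ dr r^2 [1 + delta(r)]^alpha exp[-(mu(r) - mu_TFR,i)^2 / (2 sigma_TFR^2)]
#                          N(cz_obs_i; cz_pred(r), sigma_v)
#                 / ∫ dr r^2 [1 + delta(r)]^alpha exp[-(mu(r) - mu_TFR,i)^2 / (2 sigma_TFR^2)]
# with mu(r) = 5 log10[(1 + z_cos(r)) r / h] + 25 and mu_TFR,i = m_true,i - (a_TFR + b_TFR * eta_true,i).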


def tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh):
    """
    Evaluate the terms in the likelihood from apparent magnitude for TFR tracers

    Args:
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - mthresh (float): Threshold apparent magnitude in selection

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    Nt = m_obs.shape[0]
    norm = (
        jnp.log(2)
        - jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m)))
        - 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
    )
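
    # Note: jax.scipy.special.erfc(-(mthresh - m_true) / (jnp.sqrt(2) * sigma_m)) / 2 equals the
    # standard normal CDF Phi((mthresh - m_true) / sigma_m), i.e. the probability that a tracer
    # with true magnitude m_true is observed below the selection threshold mthresh; the jnp.log(2)
    # term above converts the erfc into that CDF.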
    loglike = - (
        jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
        - jnp.sum(norm)
    )

    return loglike


def tfr_likelihood_eta(eta_true, eta_obs, sigma_eta):
    """
    Evaluate the terms in the likelihood from linewidth for TFR tracers

    Args:
        - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
        - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
        - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike = - (
        jnp.sum(0.5 * (eta_obs - eta_true) ** 2 / sigma_eta ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2))
    )

    return loglike


def tfr_likelihood(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
                   omega_m, h, L, xmin, interp_order, bias_epsilon,
                   cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh):
    """
    Evaluate the likelihood for TFR sample

    Args:
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - alpha (float): Exponent for bias model
        - a_TFR (float): TFR relation intercept
        - b_TFR (float): TFR relation slope
        - sigma_TFR (float): Intrinsic scatter in the TFR
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^alpha
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)
        - mthresh (float): Threshold apparent magnitude in selection

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike_vel = tfr_likelihood_vel(dens, vel,
                                     alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
                                     omega_m, h, L, xmin, interp_order, bias_epsilon,
                                     cz_obs, MB_pos, mthresh)
    loglike_m = tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh)
    loglike_eta = tfr_likelihood_eta(eta_true, eta_obs, sigma_eta)

    loglike = (loglike_vel + loglike_m + loglike_eta)

    return loglike
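
# Example call (a sketch with made-up toy values; every number below is purely illustrative):
#   import numpy as np
#   Nt, Nsig, Ngrid = 5, 11, 16
#   dens = jnp.zeros((Ngrid, Ngrid, Ngrid))
#   vel = jnp.zeros((3, Ngrid, Ngrid, Ngrid))
#   MB_pos = jnp.array(np.random.default_rng(0).uniform(10., 100., (3, Nt, Nsig)))
#   ll = tfr_likelihood(dens, vel, 1.4, -23., -8., 0.3, 150.,
#                       jnp.full(Nt, 12.), jnp.zeros(Nt),           # m_true, eta_true
#                       0.3, 0.7, 500., -250., 1, 1e-7,             # omega_m, h, L, xmin, interp_order, bias_epsilon
#                       jnp.full(Nt, 3000.), jnp.full(Nt, 12.),     # cz_obs, m_obs
#                       jnp.zeros(Nt), 0.1, 0.05, MB_pos, 14.)      # eta_obs, sigma_m, sigma_eta, mthresh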


def sn_likelihood_vel(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
                      omega_m, h, L, xmin, interp_order, bias_epsilon,
                      cz_obs, MB_pos):
    """
    Evaluate the terms in the likelihood from the velocity and Malmquist bias for supernovae

    Args:
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^alpha
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    # Comoving radii of integration points (Mpc/h)
    r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))

    # p_r = r^2 n(r) N(mutrue; mutripp, sigma_SN)
    # Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm)
    number_density = projection.interp_field(
        dens,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    number_density = jax.nn.relu(1. + number_density)
    number_density = jnp.power(number_density + bias_epsilon, alpha)
    zcosmo = utils.z_cos(r, omega_m)
    mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
    mutripp = m_true + a_tripp * stretch_true - b_tripp * c_true - M_SN
    d2 = ((mutrue - mutripp[:,None]) / sigma_SN) ** 2
    best = jnp.amin(jnp.abs(d2), axis=1)
    d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
    p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
    p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)

    # Peculiar velocity term
    tracer_vel = projection.interp_field(
        vel,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    tracer_vr = projection.project_radial(
        tracer_vel,
        MB_pos,
        jnp.zeros(3,)
    )
    cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
    d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2
    scale = jnp.nanmin(d2, axis=1)
    d2 = d2 - jnp.expand_dims(scale, axis=1)

    # Integrate to get likelihood
    p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
    lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2)
    loglike = lkl_ind.sum()

    return loglike


def sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch):
    """
    Evaluate the terms in the likelihood from stretch for supernovae

    Args:
        - stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,))
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike = - (
        jnp.sum(0.5 * (stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2))
    )

    return loglike


def sn_likelihood_c(c_true, c_obs, sigma_c):
    """
    Evaluate the terms in the likelihood from colour for supernovae

    Args:
        - c_true (np.ndarray): True colours of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,))
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike = - (
        jnp.sum(0.5 * (c_obs - c_true) ** 2 / sigma_c ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2))
    )

    return loglike


def sn_likelihood_m(m_true, m_obs, sigma_m):
    """
    Evaluate the terms in the likelihood from apparent magnitude for supernovae

    Args:
        - m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike = - (
        jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2))
    )

    return loglike


def sn_likelihood(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
                  omega_m, h, L, xmin, interp_order, bias_epsilon,
                  cz_obs, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos):
    """
    Evaluate the likelihood for SN sample

    Args:
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - alpha (float): Exponent for bias model
        - a_tripp (float): Coefficient of stretch in the Tripp relation
        - b_tripp (float): Coefficient of colour in the Tripp relation
        - M_SN (float): Absolute magnitude of supernovae
        - sigma_SN (float): Intrinsic scatter in the Tripp relation
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,))
        - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^alpha
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,))
        - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,))
        - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements
        - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements
        - sigma_c (float or np.ndarray): Uncertainty on the colour measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h).
            The shape is (3, Nt, Nsig)

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    loglike_vel = sn_likelihood_vel(dens, vel,
                                    alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true,
                                    omega_m, h, L, xmin, interp_order, bias_epsilon,
                                    cz_obs, MB_pos)
    loglike_stretch = sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch)
    loglike_c = sn_likelihood_c(c_true, c_obs, sigma_c)
    loglike_m = sn_likelihood_m(m_true, m_obs, sigma_m)

    loglike = loglike_vel + loglike_stretch + loglike_c + loglike_m

    return loglike
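
# Example call (a sketch with made-up toy values; every number below is purely illustrative).
# Unlike the TFR case there is no mthresh argument, since sn_likelihood_m applies no
# selection-threshold correction:
#   Nt, Nsig, Ngrid = 5, 11, 16
#   dens = jnp.zeros((Ngrid, Ngrid, Ngrid))
#   vel = jnp.zeros((3, Ngrid, Ngrid, Ngrid))
#   MB_pos = jnp.ones((3, Nt, 1)) * jnp.linspace(10., 120., Nsig) / jnp.sqrt(3.)
#   ll = sn_likelihood(dens, vel, 1.4, 0.14, 3.1, -19.3, 0.1, 150.,
#                      jnp.full(Nt, 15.), jnp.zeros(Nt), jnp.zeros(Nt),   # m_true, stretch_true, c_true
#                      0.3, 0.7, 500., -250., 1, 1e-7,                    # omega_m, h, L, xmin, interp_order, bias_epsilon
#                      jnp.full(Nt, 3000.), jnp.full(Nt, 15.),            # cz_obs, m_obs
#                      jnp.zeros(Nt), jnp.zeros(Nt),                      # stretch_obs, c_obs
#                      0.1, 0.05, 0.02, MB_pos)                           # sigma_m, sigma_stretch, sigma_c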