From b4bf7347826fae8e3d099a7034e08c5c033d03e0 Mon Sep 17 00:00:00 2001 From: Deaglan Bartlett Date: Tue, 11 Feb 2025 12:00:59 +0100 Subject: [PATCH] Successfully call loglike from BORG --- borg_velocity/borg_mock.py | 19 +- borg_velocity/likelihood.py | 165 ++++++++++-- borg_velocity/tracer_likelihoods.py | 377 ++++++++++++++++++++++++++++ 3 files changed, 539 insertions(+), 22 deletions(-) create mode 100644 borg_velocity/tracer_likelihoods.py diff --git a/borg_velocity/borg_mock.py b/borg_velocity/borg_mock.py index ee4bfdc..bc4a6ca 100644 --- a/borg_velocity/borg_mock.py +++ b/borg_velocity/borg_mock.py @@ -376,6 +376,11 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None): c_true = [None] * nsamp c_obs = [None] * nsamp xtrue = [None] * nsamp + all_sigma_m = [None] * nsamp + all_sigma_eta = [None] * nsamp + all_sigma_stretch = [None] * nsamp + all_sigma_c = [None] * nsamp + all_mthresh = [None] * nsamp for i in range(nsamp): @@ -397,12 +402,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None): hyper_eta_mu = float(config[f'sample_{i}']['hyper_eta_mu']) hyper_eta_sigma = float(config[f'sample_{i}']['hyper_eta_sigma']) mthresh = float(config[f'sample_{i}']['mthresh']) + RA[i], Dec[i], czCMB[i], m_true[i], eta_true[i], m_obs[i], eta_obs[i], xtrue[i] = tfr_create_mock( Nt, L, xmin, cosmo, dens, vel + vbulk, Rmax, alpha, mthresh, a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma, sigma_v, interp_order=interp_order, bias_epsilon=bias_epsilon) + all_sigma_m[i] = np.full(Nt, sigma_m) + all_sigma_eta[i] = np.full(Nt, sigma_eta) + all_mthresh[i] = mthresh + elif tracer_type[i] == 'sn': a_tripp = float(config[f'sample_{i}']['a_tripp']) b_tripp = float(config[f'sample_{i}']['b_tripp']) @@ -415,12 +425,17 @@ def borg_mock(s_hat, state, fwd_model, fwd_vel, ini_file, seed=None): hyper_stretch_sigma = float(config[f'sample_{i}']['hyper_stretch_sigma']) hyper_c_mu = float(config[f'sample_{i}']['hyper_c_mu']) hyper_c_sigma = float(config[f'sample_{i}']['hyper_c_sigma']) + RA[i], Dec[i], czCMB[i], m_true[i], stretch_true[i], c_true[i], m_obs[i], stretch_obs[i], c_obs[i], xtrue[i] = sn_create_mock( Nt, L, xmin, cosmo, dens, vel + vbulk, Rmax, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_m, sigma_stretch, sigma_c, hyper_stretch_mu, hyper_stretch_sigma, hyper_c_mu, hyper_c_sigma, sigma_v, interp_order=1, bias_epsilon=1e-7) + all_sigma_m[i] = np.full(Nt, sigma_m) + all_sigma_stretch[i] = np.full(Nt, sigma_stretch) + all_sigma_c[i] = np.full(Nt, sigma_c) + else: raise NotImplementedError - - return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue + + return tracer_type, RA, Dec, czCMB, m_true, m_obs, eta_true, eta_obs, stretch_true, stretch_obs, c_true, c_obs, xtrue, all_sigma_m, all_sigma_eta, all_sigma_stretch, all_sigma_c, all_mthresh diff --git a/borg_velocity/likelihood.py b/borg_velocity/likelihood.py index b9b8214..6d6f331 100644 --- a/borg_velocity/likelihood.py +++ b/borg_velocity/likelihood.py @@ -1,18 +1,21 @@ import aquila_borg as borg import jax -import numpy as np +import jaxlib import jax.numpy as jnp +import numpy as np import configparser import h5py import warnings import ast import re from functools import partial +import numbers import borg_velocity.utils as utils from borg_velocity.utils import myprint, generateMBData, compute_As, get_sigma_bulk from borg_velocity.borg_mock import borg_mock import borg_velocity.forwards as forwards +import borg_velocity.tracer_likelihoods as tracer_likelihoods from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler, TransformedBlackJaxBiasSampler class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): @@ -40,9 +43,12 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): # Initialise grid parameters self.N = [int(config['system'][f'N{i}']) for i in range(3)] self.L = [float(config['system'][f'L{i}']) for i in range(3)] + self.corner = [float(config['system'][f'corner{i}']) for i in range(3)] # For log-likelihood values self.bignum = float(config['mcmc']['bignum']) + self.interp_order = int(config['model']['interp_order']) + self.bias_epsilon = float(config['model']['bias_epsilon']) myprint(f" Init {self.N}, {self.L}") super().__init__(fwd, self.N, self.L) @@ -52,16 +58,59 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): self.fwd_param = param_model self.fwd_vel = fwd_vel + # Number of samples of distance tracers + self.nsamp = int(config['run']['nsamp']) + self.tracer_type = [config[f'sample_{i}']['tracer_type'] for i in range(self.nsamp)] + self.Nt = [int(config[f'sample_{i}']['Nt']) for i in range(self.nsamp)] + # Initialise cosmological parameters cpar = utils.get_cosmopar(self.ini_file) self.fwd.setCosmoParams(cpar) self.fwd_param.setCosmoParams(cpar) self.updateCosmology(cpar) myprint(f"Original cosmological parameters: {self.fwd.getCosmoParams()}") - + # Initialise model parameters - self.sig_v = float(config['model']['sig_v']) - self.bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow'])) + bulk_flow = np.array(ast.literal_eval(config['model']['bulk_flow'])) + self.model_params = { + 'sig_v': float(config['model']['sig_v']), + 'bulk_flow_x': bulk_flow[0], + 'bulk_flow_y': bulk_flow[1], + 'bulk_flow_z': bulk_flow[2], + } + self.vector_model_params = {} + for i in range(self.nsamp): + tracer_type = config[f'sample_{i}']['tracer_type'] + Nt = int(config[f'sample_{i}']['Nt']) + self.model_params[f'alpha_{i}'] = float(config[f'sample_{i}']['alpha']) + self.vector_model_params[f'm_true_{i}'] = np.zeros(Nt) + self.model_params[f'hyper_mean_m_{i}'] = 0. + self.model_params[f'hyper_sigma_m_{i}'] = 1. + if tracer_type == 'tfr': + self.vector_model_params[f'eta_true_{i}'] = np.zeros(Nt) + self.model_params[f'a_TFR_{i}'] = float(config[f'sample_{i}']['a_tfr']) + self.model_params[f'b_TFR_{i}'] = float(config[f'sample_{i}']['b_tfr']) + self.model_params[f'sigma_TFR_{i}'] = float(config[f'sample_{i}']['sigma_tfr']) + self.model_params[f'hyper_mean_eta_{i}'] = 0. + self.model_params[f'hyper_sigma_eta_{i}'] = 1. + self.model_params[f'hyper_corr_meta_{i}'] = 0. + elif tracer_type == 'sn': + self.vector_model_params[f'stretch_true_{i}'] = np.zeros(Nt) + self.vector_model_params[f'c_true_{i}'] = np.zeros(Nt) + self.model_params[f'a_tripp_{i}'] = float(config[f'sample_{i}']['a_tripp']) + self.model_params[f'b_tripp_{i}'] = float(config[f'sample_{i}']['b_tripp']) + self.model_params[f'M_SN_{i}'] = float(config[f'sample_{i}']['m_sn']) + self.model_params[f'sigma_SN_{i}'] = float(config[f'sample_{i}']['sigma_sn']) + self.model_params[f'hyper_mean_stretch_{i}'] = 0. + self.model_params[f'hyper_sigma_stretch_{i}'] = 1. + self.model_params[f'hyper_mean_c_{i}'] = 0. + self.model_params[f'hyper_sigma_c_{i}'] = 1. + self.model_params[f'hyper_corr_mx_{i}'] = 0. + self.model_params[f'hyper_corr_mc_{i}'] = 0. + self.model_params[f'hyper_corr_xc_{i}'] = 0. + else: + raise ValueError + self.fwd_param.setModelParams(self.model_params) # Initialise derivative self.grad_like = jax.grad(self.dens2like, argnums=(0,1)) @@ -78,9 +127,6 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): # Seed if creating mocks self.mock_seed = int(config['mock']['seed']) - # Number of samples to distance tracers - self.nsamp = int(config['run']['nsamp']) - # Initialise integration variables if config['model']['R_lim'] == 'none': self.R_lim = self.L[0]/2 @@ -103,6 +149,17 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): myprint("Init likelihood") state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True) + for i in range(self.nsamp): + if self.tracer_type[i] == 'tfr': + state.newArray1d(f"m_true_{i}", self.Nt[i], True) + state.newArray1d(f"eta_true_{i}", self.Nt[i], True) + elif self.tracer_type[i] == 'sn': + state.newArray1d(f"m_true_{i}", self.Nt[i], True) + state.newArray1d(f"stretch_true_{i}", self.Nt[i], True) + state.newArray1d(f"c_true_{i}", self.Nt[i], True) + else: + raise ValueError(f"Unknown tracer type: {self.tracer_type[i]}") + if self.run_type == 'data': self.loadObservedData() elif borg.EMBEDDED: @@ -162,21 +219,67 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood. - make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation """ + + self.data = {} + self.truths = {} if self.run_type == 'data': raise NotImplementedError elif self.run_type == 'velmass': raise NotImplementedError elif self.run_type == 'mock': - self.tracer_type, self.RA, self.Dec, self.cz_obs, self.m_true, self.m_obs, self.eta_true, self.eta_obs, self.stretch_true, self.stretch_obs, self.c_true, self.c_obs, self.xtrue = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed) + res = borg_mock(s_hat, state, self.fwd, self.fwd_vel, self.ini_file, seed=self.mock_seed) + self.data['tracer_type'] = res[0] + self.data['RA'] = res[1] + self.data['Dec'] = res[2] + self.data['cz_obs'] = res[3] + self.truths['m_true'] = res[4] + self.data['m_obs'] = res[5] + self.truths['eta_true'] = res[6] + self.data['eta_obs'] = res[7] + self.truths['stretch_true'] = res[8] + self.data['stretch_obs'] = res[9] + self.truths['c_true'] = res[10] + self.data['c_obs'] = res[11] + self.truths['xtrue'] = res[12] + self.data['sigma_m'] = res[13] + self.data['sigma_eta'] = res[14] + self.data['sigma_stretch'] = res[15] + self.data['sigma_c'] = res[16] + self.data['mthresh'] = res[17] else: raise NotImplementedError - + + # Get integration points for i in range(self.nsamp): self.MB_pos[i] = utils.generateMBData( - self.RA[i], self.Dec[i], self.cz_obs[i], self.L[0], self.N[0], - self.R_lim, self.Nsig, self.Nint_points, self.sig_v, self.frac_sigma_r[i]) + self.data['RA'][i], self.data['Dec'][i], self.data['cz_obs'][i], self.L[0], self.N[0], + self.R_lim, self.Nsig, self.Nint_points, self.model_params['sig_v'], self.frac_sigma_r[i]) + # Initialise model parameters + for i in range(self.nsamp): + state[f'm_true_{i}'][:] = self.data['m_obs'][i] + if self.data['tracer_type'][i] == 'tfr': + state[f'eta_true_{i}'][:] = self.data['eta_obs'][i] + self.model_params[f'hyper_mean_m_{i}'] = np.amax(self.data['m_obs'][i]) + self.model_params[f'hyper_sigma_m_{i}'] = np.amax(self.data['m_obs'][i]) - np.percentile(self.data['m_obs'][i], 16) + self.model_params[f'hyper_mean_eta_{i}'] = np.median(self.data['eta_obs'][i]) + self.model_params[f'hyper_sigma_eta_{i}'] = (np.percentile(self.data['eta_obs'][i], 84) - np.percentile(self.data['eta_obs'][i], 16)) / 2 + elif self.data['tracer_type'][i] == 'sn': + state[f'stretch_true_{i}'][:] = self.data['stretch_obs'][i] + state[f'c_true_{i}'][:] = self.data['c_obs'][i] + self.model_params[f'hyper_mean_m_{i}'] = np.median(self.data['m_obs'][i]) + self.model_params[f'hyper_sigma_m_{i}'] = (np.percentile(self.data['m_obs'][i], 84) - np.percentile(self.data['m_obs'][i], 16)) / 2 + self.model_params[f'hyper_mean_stretch_{i}'] = np.median(self.data['stretch_obs'][i]) + self.model_params[f'hyper_sigma_stretch_{i}'] = (np.percentile(self.data['stretch_obs'][i], 84) - np.percentile(self.data['stretch_obs'][i], 16)) / 2 + self.model_params[f'hyper_mean_c_{i}'] = np.median(self.data['c_obs'][i]) + self.model_params[f'hyper_sigma_stretch_{i}'] = (np.percentile(self.data['c_obs'][i], 84) - np.percentile(self.data['c_obs'][i], 16)) / 2 + else: + raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}") + self.fwd_param.setModelParams(self.model_params) + for k in self.vector_model_params.keys(): + self.vector_model_params[k] = state[k] + # Save mock to file # with h5py.File(f'tracer_data_{self.run_type}.h5', 'w') as h5file: # for i in range(self.nsamp): @@ -213,23 +316,43 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood): lkl = 0 - sig_v = self.model_params['sig_v'] + sigma_v = self.model_params['sig_v'] # Compute velocity field - self.bulk_flow = jnp.array([self.model_params['bulk_flow_x'], + bulk_flow = jnp.array([self.model_params['bulk_flow_x'], self.model_params['bulk_flow_y'], self.model_params['bulk_flow_z']]) - v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1)) + vel = output_velocity + bulk_flow.reshape((3, 1, 1, 1)) omega_m = self.fwd.getCosmoParams().omega_m + h = self.fwd.getCosmoParams().h - raise NotImplementedError + loglike = 0. + for i in range(self.nsamp): + if self.data['tracer_type'][i] == 'tfr': + loglike += tracer_likelihoods.tfr_likelihood( + output_density, vel, self.model_params[f'alpha_{i}'], + self.model_params[f'a_TFR_{i}'], self.model_params[f'b_TFR_{i}'], self.model_params[f'sigma_TFR_{i}'], sigma_v, + self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'eta_true_{i}'], + omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon, + self.data['cz_obs'][i], self.data['m_obs'][i], self.data['eta_obs'][i], + self.data['sigma_m'][i], self.data['sigma_eta'][i], self.MB_pos[i], self.data['mthresh'][i]) + elif self.data['tracer_type'][i] == 'sn': + loglike += tracer_likelihoods.sn_likelihood( + output_density, vel, self.model_params[f'alpha_{i}'], + self.model_params[f'a_tripp_{i}'], self.model_params[f'b_tripp_{i}'], + self.model_params[f'M_SN_{i}'], self.model_params[f'sigma_SN_{i}'], sigma_v, + self.vector_model_params[f'm_true_{i}'], self.vector_model_params[f'stretch_true_{i}'], self.vector_model_params[f'c_true_{i}'], + omega_m, h, self.L[0], self.corner[0], self.interp_order, self.bias_epsilon, + self.data['cz_obs'][i], self.data['m_obs'][i], self.data['stretch_obs'][i], self.data['c_obs'][i], + self.data['sigma_m'][i], self.data['sigma_stretch'][i], self.data['sigma_c'][i], self.MB_pos[i]) + else: + raise ValueError(f"Unknown tracer type: {self.data['tracer_type'][i]}") # Add in bulk flow prior - lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + self.bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2) + lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(self.sigma_bulk / jnp.sqrt(3)) + bulk_flow ** 2 / 2. / (self.sigma_bulk / jnp.sqrt(3)) ** 2) - # lkl = jnp.clip(lkl, -self.bignum, self.bignum) - lkl = lax.cond( + lkl = jax.lax.cond( jnp.isfinite(lkl), lambda _: lkl, # If True (finite), return lkl lambda _: self.bignum, # If False (not finite), return self.bignum @@ -697,12 +820,14 @@ def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.L """ To Do: - Add in the true colours etc. to the model parameter sampling -- Call to the correct log-likelihoods - Sample the variables from priors - Check global variables are used -- Create minimal ini file - Docstrings - Don't want all these samplers - only gradient ones - Save mock to file +- How to sample the true values now they are not in model_params? +- How too update the vector of true values for likelihood call? +- Commit true values to file +- Check changes to sigma_m etc. likelihoods in original tests """ diff --git a/borg_velocity/tracer_likelihoods.py b/borg_velocity/tracer_likelihoods.py new file mode 100644 index 0000000..65672e0 --- /dev/null +++ b/borg_velocity/tracer_likelihoods.py @@ -0,0 +1,377 @@ +import jax +import jax.numpy as jnp + +import borg_velocity.projection as projection +import borg_velocity.utils as utils + +def tfr_likelihood_vel(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, MB_pos, mthresh): + """ + Evaluate the terms in the likelihood from the velocity and malmquist bias for TFR tracers + + Args: + - alpha (float): Exponent for bias model + - a_TFR (float): TFR relation intercept + - b_TFR (float): TFR relation slope + - sigma_TFR (float): Intrinsic scatter in the TFR + - sigma_v (float): Uncertainty on the velocity field (km/s) + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - omega_m (float): Matter density parameter Om + - h (float): Hubble constant H0 = 100 h km/s/Mpc + - L (float): Comoving box size (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - interp_order (int): Order of interpolation from grid points to the line of sight + - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# + - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). + The shape is (3, Nt, Nsig) + - mthresh (float): Threshold absolute magnitude in selection + + Returns: + - loglike (float): The log-likelihood of the data + """ + + # Comoving radii of integration points (Mpc/h) + r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0)) + + # p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR) + # Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm) + number_density = projection.interp_field( + dens, + MB_pos, + L, + jnp.array([xmin, xmin, xmin]), + interp_order, + use_jitted=True, + ) + number_density = jax.nn.relu(1. + number_density) + number_density = jnp.power(number_density + bias_epsilon, alpha) + zcosmo = utils.z_cos(r, omega_m) + mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25 + muTFR = m_true - (a_TFR + b_TFR * eta_true) + d2 = ((mutrue - muTFR[:,None]) / sigma_TFR) ** 2 + best = jnp.amin(jnp.abs(d2), axis=1) + d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1) + p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density + p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1) + + # Peculiar velocity term + tracer_vel = projection.interp_field( + vel, + MB_pos, + L, + jnp.array([xmin, xmin, xmin]), + interp_order, + use_jitted=True, + ) + tracer_vr = projection.project_radial( + tracer_vel, + MB_pos, + jnp.zeros(3,) + ) + cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light + d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2 + scale = jnp.nanmin(d2, axis=1) + d2 = d2 - jnp.expand_dims(scale, axis=1) + + # Integrate to get likelihood + p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1) + lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2) + loglike = lkl_ind.sum() + + return loglike + + +def tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh): + """ + Evaluate the terms in the likelihood from apparent magnitude for TFR tracers + + Args: + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - mthresh (float): Threshold absolute magnitude in selection + + Returns: + - loglike (float): The log-likelihood of the data + """ + + Nt = m_obs.shape[0] + norm = ( + jnp.log(2) + - jnp.log(jax.scipy.special.erfc(- (mthresh - m_true) / (jnp.sqrt(2) * sigma_m))) + - 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2) + ) + loglike = - ( + jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)) + - jnp.sum(norm) + ) + + return loglike + + +def tfr_likelihood_eta(eta_true, eta_obs, sigma_eta): + """ + Evaluate the terms in the likelihood from linewidth for TFR tracers + + Args: + - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) + - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) + - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements + + Returns: + - loglike (float): The log-likelihood of the data + """ + + loglike = - ( + jnp.sum(0.5 * (eta_obs - eta_true) ** 2 / sigma_eta ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2)) + ) + + return loglike + + +def tfr_likelihood(dens, vel, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh): + """ + Evaluate the likelihood for TFR sample + + Args: + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - alpha (float): Exponent for bias model + - a_TFR (float): TFR relation intercept + - b_TFR (float): TFR relation slope + - sigma_TFR (float): Intrinsic scatter in the TFR + - sigma_v (float): Uncertainty on the velocity field (km/s) + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,)) + - omega_m (float): Matter density parameter Om + - h (float): Hubble constant H0 = 100 h km/s/Mpc + - L (float): Comoving box size (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - interp_order (int): Order of interpolation from grid points to the line of sight + - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# + - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,)) + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_eta (float or np.ndarray): Uncertainty on the linewidth measurements + - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). + The shape is (3, Nt, Nsig) + - mthresh (float): Threshold absolute magnitude in selection + + Returns: + - loglike (float): The log-likelihood of the data + """ + + + loglike_vel = tfr_likelihood_vel(dens, vel, + alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, MB_pos, mthresh) + loglike_m = tfr_likelihood_m(m_true, m_obs, sigma_m, mthresh) + loglike_eta = tfr_likelihood_eta(eta_true, eta_obs, sigma_eta) + + loglike = (loglike_vel + loglike_m + loglike_eta) + + return loglike + + +def sn_likelihood_vel(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, MB_pos): + """ + Evaluate the terms in the likelihood from the velocity and malmquist bias for supernovae + + Args: + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - alpha (float): Exponent for bias model + - a_tripp (float): Coefficient of stretch in the Tripp relation + - b_tripp (float): Coefficient of colour in the Tripp relation + - M_SN (float): Absolute magnitude of supernovae + - sigma_SN (float): Intrinsic scatter in the Tripp relation + - sigma_v (float): Uncertainty on the velocity field (km/s) + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,)) + - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,)) + - omega_m (float): Matter density parameter Om + - h (float): Hubble constant H0 = 100 h km/s/Mpc + - L (float): Comoving box size (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - interp_order (int): Order of interpolation from grid points to the line of sight + - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# + - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). + The shape is (3, Nt, Nsig) + + Returns: + - loglike (float): The log-likelihood of the data + """ + + # Comoving radii of integration points (Mpc/h) + r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0)) + + # p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR) + # Multiply by arbitrary number for numerical stability (cancels in p_r / p_r_norm) + number_density = projection.interp_field( + dens, + MB_pos, + L, + jnp.array([xmin, xmin, xmin]), + interp_order, + use_jitted=True, + ) + number_density = jax.nn.relu(1. + number_density) + number_density = jnp.power(number_density + bias_epsilon, alpha) + zcosmo = utils.z_cos(r, omega_m) + mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25 + mutripp = m_true + a_tripp * stretch_true - b_tripp * c_true - M_SN + d2 = ((mutrue - mutripp[:,None]) / sigma_SN) ** 2 + best = jnp.amin(jnp.abs(d2), axis=1) + d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1) + p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density + p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1) + + # Peculiar velocity term + tracer_vel = projection.interp_field( + vel, + MB_pos, + L, + jnp.array([xmin, xmin, xmin]), + interp_order, + use_jitted=True, + ) + tracer_vr = projection.project_radial( + tracer_vel, + MB_pos, + jnp.zeros(3,) + ) + cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light + d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v)**2 + scale = jnp.nanmin(d2, axis=1) + d2 = d2 - jnp.expand_dims(scale, axis=1) + + # Integrate to get likelihood + p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1) + lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * jnp.pi * sigma_v**2) + loglike = lkl_ind.sum() + + return loglike + + +def sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch): + """ + Evaluate the terms in the likelihood from stretch for supernovae + + Args: + - stretch_true (np.ndarray): True stretch of the tracers (shape = (Nt,)) + - stretch_obs (np.ndarray): Observed stretch of the tracers (shape = (Nt,)) + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements + + Returns: + - loglike (float): The log-likelihood of the data + """ + + loglike = - ( + jnp.sum(0.5 * (stretch_obs - stretch_true) ** 2 / sigma_stretch ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_stretch ** 2)) + ) + + return loglike + + +def sn_likelihood_c(c_true, c_obs, sigma_c): + """ + Evaluate the terms in the likelihood from colour for supernovae + + Args: + - c_true (np.ndarray): True colours of the tracers (shape = (Nt,)) + - c_obs (np.ndarray): Observed colours of the tracers (shape = (Nt,)) + - sigma_c (float or np.ndarray): Uncertainty on the colours measurements + + Returns: + - loglike (float): The log-likelihood of the data + """ + + loglike = - ( + jnp.sum(0.5 * (c_obs - c_true) ** 2 / sigma_c ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_c ** 2)) + ) + + return loglike + + +def sn_likelihood_m(m_true, m_obs, sigma_m): + """ + Evaluate the terms in the likelihood from apparent magntiude for supernovae + + Args: + - m_true (np.ndarray): True apparent magnitude of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitude of the tracers (shape = (Nt,)) + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + + Returns: + - loglike (float): The log-likelihood of the data + """ + + loglike = - ( + jnp.sum(0.5 * (m_obs - m_true) ** 2 / sigma_m ** 2 + 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)) + ) + + return loglike + + +def sn_likelihood(dens, vel, alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, m_obs, stretch_obs, c_obs, sigma_m, sigma_stretch, sigma_c, MB_pos): + """ + Evaluate the likelihood for SN sample + + Args: + - dens (np.ndarray): Over-density field (shape = (N, N, N)) + - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N)) + - alpha (float): Exponent for bias model + - a_tripp (float): Coefficient of stretch in the Tripp relation + - b_tripp (float): Coefficient of colour in the Tripp relation + - M_SN (float): Absolute magnitude of supernovae + - sigma_SN (float): Intrinsic scatter in the Tripp relation + - sigma_v (float): Uncertainty on the velocity field (km/s) + - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)) + - stretch_true (np.ndarray): True stretch values of the tracers (shape = (Nt,)) + - c_true (np.ndarray): True colour values of the tracers (shape = (Nt,)) + - omega_m (float): Matter density parameter Om + - h (float): Hubble constant H0 = 100 h km/s/Mpc + - L (float): Comoving box size (Mpc/h) + - xmin (float): Coordinate of corner of the box (Mpc/h) + - interp_order (int): Order of interpolation from grid points to the line of sight + - bias_epsilon (float): Small number to add to 1 + delta to prevent 0^# + - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,)) + - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,)) + - stretch_obs (np.ndarray): Observed stretch values of the tracers (shape = (Nt,)) + - c_obs (np.ndarray): Observed colour values of the tracers (shape = (Nt,)) + - sigma_m (float or np.ndarray): Uncertainty on the apparent magnitude measurements + - sigma_stretch (float or np.ndarray): Uncertainty on the stretch measurements + - sigma_c (float or np.ndarray): Uncertainty on the colour measurements + - MB_pos (np.ndarray): Comoving coordinates of integration points to use in likelihood (Mpc/h). + The shape is (3, Nt, Nsig) + + Returns: + - loglike (float): The log-likelihood of the data + """ + + + loglike_vel = sn_likelihood_vel(dens, vel, + alpha, a_tripp, b_tripp, M_SN, sigma_SN, sigma_v, m_true, stretch_true, c_true, + omega_m, h, L, xmin, interp_order, bias_epsilon, + cz_obs, MB_pos) + loglike_stretch = sn_likelihood_stretch(stretch_true, stretch_obs, sigma_stretch) + loglike_c = sn_likelihood_c(c_true, c_obs, sigma_c) + loglike_m = sn_likelihood_m(m_true, m_obs, sigma_m) + + loglike = loglike_vel + loglike_stretch + loglike_c + loglike_m + + return loglike \ No newline at end of file