import aquila_borg as borg
import numpy as np
from astropy.coordinates import SkyCoord
import astropy.units as apu
import astropy.constants
import pandas as pd
import jax
import jax.numpy as jnp
import jax.scipy.special
import matplotlib.pyplot as plt
import corner
import numpyro
import numpyro.distributions as dist
from jax import lax, random

import borg_velocity.poisson_process as poisson_process
import borg_velocity.projection as projection
import borg_velocity.utils as utils

# Output stream management
cons = borg.console()
myprint = lambda x: cons.print_std(x) if isinstance(x, str) else cons.print_std(repr(x))


def build_gravity_model(box, cpar, ai=0.05, af=1.0, nsteps=20, forcesampling=4,
                        supersampling=2, rsmooth=4.0, gravity='lpt',
                        velmodel_name='CICModel'):
    """
    Builds the gravity model and returns the forward model chain.

    Args:
        - box (borg.forward.BoxModel): Box within which to run the simulation
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - ai (float, default=0.05): Scale factor at which to begin the simulation
        - af (float, default=1.0): Scale factor at which to end the simulation
        - nsteps (int, default=20): Number of steps to use in the simulation
        - forcesampling (int, default=4): Sampling factor for force evaluations
        - supersampling (int, default=2): Supersampling factor for the particles
        - rsmooth (float, default=4.0): Smoothing scale for the velocity field (Mpc/h)
        - gravity (str, default='lpt'): Which gravity model to use
        - velmodel_name (str, default='CICModel'): Which velocity estimator to use

    Returns:
        - chain (borg.forward.BaseForwardModel): The forward model for density
        - fwd_vel (borg.forward.VelocityBase): The forward model for velocity
    """

    myprint(f"Building gravity model {gravity}")

    # Setup forward model
    chain = borg.forward.ChainForwardModel(box)
    chain.addModel(borg.forward.models.HermiticEnforcer(box))

    # CLASS transfer function
    chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
    transfer_class = borg.forward.model_lib.M_TRANSFER_CLASS(box, opts=dict(a_transfer=1.0))
    transfer_class.setModelParams({"extra_class_arguments": {'YHe': '0.24'}})
    chain @= transfer_class

    if gravity == 'linear':
        raise NotImplementedError(gravity)
    elif gravity == 'lpt':
        mod = borg.forward.model_lib.M_LPT_CIC(
            box,
            opts=dict(a_initial=af, a_final=af, do_rsd=False,
                      supersampling=supersampling, lightcone=False,
                      part_factor=1.01))
    elif gravity == '2lpt':
        mod = borg.forward.model_lib.M_2LPT_CIC(
            box,
            opts=dict(a_initial=af, a_final=af, do_rsd=False,
                      supersampling=supersampling, lightcone=False,
                      part_factor=1.01))
    elif gravity == 'pm':
        mod = borg.forward.model_lib.M_PM_CIC(
            box,
            opts=dict(a_initial=af, a_final=af, do_rsd=False,
                      supersampling=supersampling, part_factor=1.01,
                      forcesampling=forcesampling, pm_start_z=1/ai - 1,
                      pm_nsteps=nsteps, tcola=False))
    elif gravity == 'cola':
        mod = borg.forward.model_lib.M_PM_CIC(
            box,
            opts=dict(a_initial=af, a_final=af, do_rsd=False,
                      supersampling=supersampling, part_factor=1.01,
                      forcesampling=forcesampling, pm_start_z=1/ai - 1,
                      pm_nsteps=nsteps, tcola=True))
    else:
        raise NotImplementedError(gravity)

    mod.accumulateAdjoint(True)
    chain @= mod

    # Cosmological parameters
    chain.setCosmoParams(cpar)

    # This is the forward model for the velocity field
    velmodel = getattr(borg.forward.velocity, velmodel_name)
    if velmodel_name == 'LinearModel':
        fwd_vel = velmodel(box, mod, af)
    elif velmodel_name == 'CICModel':
        fwd_vel = velmodel(box, mod, rsmooth)
    else:
        fwd_vel = velmodel(box, mod)

    return chain, fwd_vel
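
# Note: BORG's ChainForwardModel composes sub-models with the @= operator, so
# the chain above applies, in order: Hermitian symmetry enforcement, the
# primordial power spectrum, the CLASS transfer function, and the chosen
# gravity solver. A minimal usage sketch (the box values are illustrative,
# mirroring get_fields below):
#
#   box = borg.forward.BoxModel()
#   box.L, box.N, box.xmin = (500.0,) * 3, (64,) * 3, (-250.0,) * 3
#   cpar = borg.cosmo.CosmologicalParameters()
#   cpar.default()
#   chain, fwd_vel = build_gravity_model(box, cpar, gravity='lpt')
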
def get_fields(L, N, xmin, gravity='lpt', velmodel_name='CICModel'):
    """
    Obtain a density and velocity field to use for the mock

    Args:
        - L (float): Box length (Mpc/h)
        - N (int): Number of grid cells per side
        - xmin (float): Coordinate of the corner of the box (Mpc/h)
        - gravity (str, default='lpt'): Which gravity model to use
        - velmodel_name (str, default='CICModel'): Which velocity estimator to use

    Returns:
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters used
        - output_density (np.ndarray): Over-density field
        - output_velocity (np.ndarray): Velocity field
    """

    # Setup box and cosmology
    cpar = borg.cosmo.CosmologicalParameters()
    cpar.default()
    box = borg.forward.BoxModel()
    box.L = (L, L, L)
    box.N = (N, N, N)
    box.xmin = (xmin, xmin, xmin)

    # Get forward models
    fwd_model, fwd_vel = build_gravity_model(box, cpar, gravity=gravity,
                                             velmodel_name=velmodel_name)

    # Make some initial conditions
    s_hat = np.fft.rfftn(np.random.randn(*box.N)) / box.Ntot ** (0.5)

    # Obtain density and velocity fields
    output_density = np.zeros(box.N)
    fwd_model.forwardModel_v2(s_hat)
    fwd_model.getDensityFinal(output_density)
    output_velocity = fwd_vel.getVelocityField()

    return cpar, output_density, output_velocity


def create_mock(Nt, L, xmin, cpar, dens, vel, Rmax, alpha, mthresh,
                a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta,
                hyper_eta_mu, hyper_eta_sigma, sigma_v,
                interp_order=1, bias_epsilon=1e-7):
    """
    Create a mock TFR catalogue from the density and velocity fields.

    Args:
        - Nt (int): Number of tracers to generate
        - L (float): Box length (Mpc/h)
        - xmin (float): Coordinate of the corner of the box (Mpc/h)
        - cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters to use
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - Rmax (float): Maximum comoving radius at which to place tracers (Mpc/h)
        - alpha (float): Exponent for the bias model
        - mthresh (float): Apparent magnitude threshold in the selection
        - a_TFR (float): TFR relation intercept
        - b_TFR (float): TFR relation slope
        - sigma_TFR (float): Intrinsic scatter in the TFR
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_eta (float): Uncertainty on the linewidth measurements
        - hyper_eta_mu (float): Mean of the Gaussian prior on the true linewidths
        - hyper_eta_sigma (float): Standard deviation of the Gaussian prior on the true linewidths
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - interp_order (int, default=1): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float, default=1e-7): Small number to add to 1 + delta to avoid evaluating 0 ** alpha

    Returns:
        - all_RA (np.ndarray): Right Ascension (degrees) of the tracers (shape = (Nt,))
        - all_Dec (np.ndarray): Declination (degrees) of the tracers (shape = (Nt,))
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - all_mtrue (np.ndarray): True apparent magnitudes (shape = (Nt,))
        - all_etatrue (np.ndarray): True linewidths (shape = (Nt,))
        - all_mobs (np.ndarray): Observed apparent magnitudes (shape = (Nt,))
        - all_etaobs (np.ndarray): Observed linewidths (shape = (Nt,))
        - all_xtrue (np.ndarray): True comoving positions of the tracers (Mpc/h) (shape = (3, Nt))
    """

    # Initialize arrays to store valid positions and corresponding observables
    all_xtrue = np.empty((3, Nt))
    all_mtrue = np.empty(Nt)
    all_etatrue = np.empty(Nt)
    all_mobs = np.empty(Nt)
    all_etaobs = np.empty(Nt)
    all_RA = np.empty(Nt)
    all_Dec = np.empty(Nt)

    # Counter for accepted positions
    accepted_count = 0

    # Bias model
    phi = (1. + dens + bias_epsilon) ** alpha

    # Only use the centre of the box
    x = np.linspace(xmin, xmin + L, dens.shape[0] + 1)
    i0 = np.argmin(np.abs(x + Rmax))
    i1 = np.argmin(np.abs(x - Rmax))
    L_small = x[i1] - x[i0]
    xmin_small = x[i0]
    phi_small = phi[i0:i1, i0:i1, i0:i1]

    # Loop until we have Nt valid positions
    while accepted_count < Nt:

        # Generate positions (comoving)
        xtrue = poisson_process.sample_3d(phi_small, Nt, L_small,
                                          (xmin_small, xmin_small, xmin_small))

        # Convert to RA, Dec, Distance (comoving)
        rtrue = np.sqrt(np.sum(xtrue ** 2, axis=0))  # Mpc/h
        c = SkyCoord(x=xtrue[0], y=xtrue[1], z=xtrue[2],
                     representation_type='cartesian')
        RA = c.spherical.lon.degree
        Dec = c.spherical.lat.degree

        # Compute cosmological redshift
        zcosmo = utils.z_cos(rtrue, cpar.omega_m)

        # Compute luminosity distance
        # DO I NEED TO DO /h???
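        # (Unit check: rtrue is comoving in Mpc/h, so rtrue / cpar.h is in Mpc,
        # and dL = (1 + zcosmo) * rtrue / cpar.h is then in Mpc, consistent
        # with the "+ 25" convention in the distance modulus below.)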
        dL = (1 + zcosmo) * rtrue / cpar.h  # Mpc

        # Compute true distance modulus
        mutrue = 5 * np.log10(dL) + 25

        # Sample true linewidth (eta) from its prior
        etatrue = hyper_eta_mu + hyper_eta_sigma * np.random.randn(Nt)

        # Obtain muTFR from mutrue using the intrinsic scatter
        muTFR = mutrue + sigma_TFR * np.random.randn(Nt)

        # Obtain apparent magnitude from the TFR
        mtrue = muTFR + (a_TFR + b_TFR * etatrue)

        # Scatter to obtain the observed apparent magnitudes and linewidths
        mobs = mtrue + sigma_m * np.random.randn(Nt)
        etaobs = etatrue + sigma_eta * np.random.randn(Nt)

        # Apply the apparent magnitude cut
        m = mobs <= mthresh
        mtrue = mtrue[m]
        etatrue = etatrue[m]
        mobs = mobs[m]
        etaobs = etaobs[m]
        xtrue = xtrue[:, m]
        RA = RA[m]
        Dec = Dec[m]

        # Calculate how many valid positions we need to reach Nt
        remaining_needed = Nt - accepted_count
        selected_count = min(xtrue.shape[1], remaining_needed)

        # Append only the needed number of valid positions
        imin = accepted_count
        imax = accepted_count + selected_count
        all_xtrue[:, imin:imax] = xtrue[:, :selected_count]
        all_mtrue[imin:imax] = mtrue[:selected_count]
        all_etatrue[imin:imax] = etatrue[:selected_count]
        all_mobs[imin:imax] = mobs[:selected_count]
        all_etaobs[imin:imax] = etaobs[:selected_count]
        all_RA[imin:imax] = RA[:selected_count]
        all_Dec[imin:imax] = Dec[:selected_count]

        # Update the accepted count
        accepted_count += selected_count

        myprint(f'\tMade {accepted_count} of {Nt}')

    # Get the radial component of the peculiar velocity at the positions of the objects
    myprint('Obtaining peculiar velocities')
    tracer_vel = projection.interp_field(
        vel,
        np.expand_dims(all_xtrue, axis=2),
        L,
        np.array([xmin, xmin, xmin]),
        interp_order
    )  # km/s
    myprint('Radial projection')
    vr_true = np.squeeze(projection.project_radial(
        tracer_vel,
        np.expand_dims(all_xtrue, axis=2),
        np.zeros(3,)
    ))  # km/s

    # Recompute the cosmological redshift at the accepted positions
    rtrue = jnp.sqrt(jnp.sum(all_xtrue ** 2, axis=0))
    zcosmo = utils.z_cos(rtrue, cpar.omega_m)

    # Obtain total redshift: (1 + z_obs) = (1 + z_cosmo)(1 + v_r / c)
    vr_noised = vr_true + sigma_v * np.random.randn(Nt)
    czCMB = ((1 + zcosmo) * (1 + vr_noised / utils.speed_of_light) - 1) * utils.speed_of_light

    return all_RA, all_Dec, czCMB, all_mtrue, all_etatrue, all_mobs, all_etaobs, all_xtrue


def estimate_data_parameters():
    """
    Estimate nuisance parameters from the 2MASS Tully-Fisher catalogue.

    Columns of the data file:
        ID       2MASS XSC ID name (HHMMSSss+DDMMSSs)
        RAdeg    Right ascension (J2000)
        DEdeg    Declination (J2000)
        cz2mrs   Heliocentric redshift from the 2MRS (km/s)
        Kmag     NIR magnitudes in the K band from the 2MRS (mag)
        Hmag     NIR magnitudes in the H band from the 2MRS (mag)
        Jmag     NIR magnitudes in the J band from the 2MRS (mag)
        e_Kmag   Error of the NIR magnitudes in K band from the 2MRS (mag)
        e_Hmag   Error of the NIR magnitudes in H band from the 2MRS (mag)
        e_Jmag   Error of the NIR magnitudes in J band from the 2MRS (mag)
        WHIc     Corrected HI width (km/s)
        e_WHIc   Error of corrected HI width (km/s)

    Returns:
        - sigma_m (float): Median uncertainty on the K-band apparent magnitudes
        - sigma_eta (float): Median uncertainty on the linewidths eta = log10(WHIc) - 2.5
        - hyper_eta_mu (float): Median of the observed eta distribution
        - hyper_eta_sigma (float): Half the 16th-84th percentile range of the observed eta distribution
    """

    columns = ['ID', 'RAdeg', 'DEdeg', 'cz2mrs', 'Kmag', 'Hmag', 'Jmag',
               'e_Kmag', 'e_Hmag', 'e_Jmag', 'WHIc', 'e_WHIc']
    fname = '/data101/bartlett/fsigma8/PV_data/2MASS/table1.dat'
    df = pd.read_csv(fname, sep=r'\s+', names=columns)

    sigma_m = np.median(df['e_Kmag'])

    eta = np.log10(df['WHIc']) - 2.5
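    # Propagate the linewidth error to eta = log10(W) - 2.5:
    # d(eta)/dW = 1 / (W ln 10), so sigma_eta = sigma_W / (W ln 10)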
    sigma_eta = np.median(df['e_WHIc'] / df['WHIc'] / np.log(10))
    hyper_eta_mu = np.median(eta)
    hyper_eta_sigma = (np.percentile(eta, 84) - np.percentile(eta, 16)) / 2

    return sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma


def generateMBData(RA, Dec, cz_obs, L, N, R_lim, Nsig, Nint_points, sigma_v,
                   frac_sigma_r):
    """
    Generate points along the line of sight of each tracer to enable
    marginalisation over distance uncertainty. The central distance is chosen
    such that the observed redshift equals the cosmological redshift at this
    distance. The range is then +/- Nsig * sig, where
    sig^2 = (sig_v / 100)^2 + sig_r^2 and sig_v is the velocity uncertainty
    in km/s.

    Args:
        - RA (np.ndarray): Right Ascension (degrees) of the tracers (shape = (Nt,))
        - Dec (np.ndarray): Declination (degrees) of the tracers (shape = (Nt,))
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - L (float): Box length (Mpc/h)
        - N (int): Number of grid cells per side
        - R_lim (float): Maximum allowed (true) comoving distance of a tracer (Mpc/h)
        - Nsig (float): Number of standard deviations sig over which to extend
            the integration range on either side of the central distance
        - Nint_points (int): Number of radii over which to integrate the likelihood
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - frac_sigma_r (float): An estimate of the fractional uncertainty on the
            positions of tracers

    Returns:
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use
            in the likelihood (Mpc/h). The shape is (3, Nt, Nint_points).
    """

    myprint("Making MB data")

    # Convert RA, Dec to unit radial vectors
    r_hat = np.array(SkyCoord(ra=RA*apu.deg, dec=Dec*apu.deg).cartesian.xyz).T

    # Get the min and max distance to integrate over
    # cz = 100 h r, so sigma_v corresponds to a sigma_r of ~ sigma_v / 100
    robs = cz_obs / 100
    sigr = np.sqrt((sigma_v / 100) ** 2 + (frac_sigma_r * robs) ** 2)
    rmin = robs - Nsig * sigr
    rmin = rmin.at[rmin <= 0].set(L / N / 100.)
    rmax = robs + Nsig * sigr
    rmax = rmax.at[rmax > R_lim].set(R_lim)

    # Compute the coordinates of the integration points
    r_integration = np.linspace(rmin, rmax, Nint_points)
    MB_pos = np.expand_dims(r_integration, axis=2) * r_hat[None, ...]
    MB_pos = jnp.transpose(MB_pos, (2, 1, 0))

    return MB_pos
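
# The likelihood below marginalises over the unknown true comoving distance r
# of each tracer along its line of sight, evaluated with the trapezoidal rule
# on the MB_pos grid built by generateMBData:
#
#   p(cz_obs | ...) = \int dr  p(r)  N(cz_obs ; cz_pred(r), sigma_v)
#   p(r) \propto r^2 n(r) N(mu(r) ; mu_TFR, sigma_TFR)
#
# where n(r) is the (biased) tracer density interpolated onto the line of
# sight, mu(r) is the distance modulus at r, and cz_pred(r) follows from the
# cosmological redshift at r plus the radial peculiar velocity there.
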
def likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true, eta_true,
               dens, vel, omega_m, h, L, xmin, interp_order, bias_epsilon,
               cz_obs, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh):
    """
    Evaluate the likelihood for the TFR sample

    Args:
        - alpha (float): Exponent for the bias model
        - a_TFR (float): TFR relation intercept
        - b_TFR (float): TFR relation slope
        - sigma_TFR (float): Intrinsic scatter in the TFR
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of the corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to avoid evaluating 0 ** alpha
        - cz_obs (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_eta (float): Uncertainty on the linewidth measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use
            in the likelihood (Mpc/h). The shape is (3, Nt, Nint_points).
        - mthresh (float): Apparent magnitude threshold in the selection

    Returns:
        - loglike (float): The log-likelihood of the data
    """

    # Comoving radii of integration points (Mpc/h)
    r = jnp.sqrt(jnp.sum(MB_pos ** 2, axis=0))

    # p_r = r^2 n(r) N(mutrue; muTFR, sigmaTFR)
    # Multiplied by an arbitrary factor for numerical stability (it cancels in p_r / p_r_norm)
    number_density = projection.interp_field(
        dens,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    number_density = jax.nn.relu(1. + number_density)
    number_density = jnp.power(number_density + bias_epsilon, alpha)
    zcosmo = utils.z_cos(r, omega_m)
    mutrue = 5 * jnp.log10((1 + zcosmo) * r / h) + 25
    muTFR = m_true - (a_TFR + b_TFR * eta_true)
    d2 = ((mutrue - muTFR[:, None]) / sigma_TFR) ** 2
    d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
    p_r = r ** 2 * jnp.exp(-0.5 * d2) * number_density
    p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)

    # Peculiar velocity term
    tracer_vel = projection.interp_field(
        vel,
        MB_pos,
        L,
        jnp.array([xmin, xmin, xmin]),
        interp_order,
        use_jitted=True,
    )
    tracer_vr = projection.project_radial(
        tracer_vel,
        MB_pos,
        jnp.zeros(3,)
    )
    cz_pred = ((1 + zcosmo) * (1 + tracer_vr / utils.speed_of_light) - 1) * utils.speed_of_light
    d2 = ((cz_pred - jnp.expand_dims(cz_obs, axis=1)) / sigma_v) ** 2
    scale = jnp.nanmin(d2, axis=1)
    d2 = d2 - jnp.expand_dims(scale, axis=1)

    # Integrate to get the likelihood
    p_cz = jnp.trapezoid(jnp.exp(-0.5 * d2) * p_r / p_r_norm, r, axis=1)
    lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * np.pi * sigma_v ** 2)
    loglike_vel = - lkl_ind.sum()

    Nt = m_obs.shape[0]

    # Apparent magnitude terms (Gaussian noise truncated at the selection threshold mthresh)
    norm = 0.5 * (1 + jax.scipy.special.erf((mthresh - m_true) / (jnp.sqrt(2) * sigma_m)))
    loglike_m = (
        0.5 * jnp.sum((m_obs - m_true) ** 2 / sigma_m ** 2)
        + jnp.sum(jnp.log(norm))
        + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_m ** 2)
    )

    # Linewidth terms
    loglike_eta = (
        0.5 * jnp.sum((eta_obs - eta_true) ** 2 / sigma_eta ** 2)
        + Nt * 0.5 * jnp.log(2 * jnp.pi * sigma_eta ** 2)
    )

    # Velocity term currently disabled (see TO DO list at the end of the file)
    # loglike = - (loglike_vel + loglike_m + loglike_eta)
    loglike = - (loglike_eta + loglike_m)

    return loglike
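
# The shift-by-the-minimum applied to d2 above is a log-sum-exp style
# stabilisation: exp(-0.5 * d2) underflows for large d2, so the smallest d2 is
# factored out before integrating and added back in log space. A minimal
# self-contained illustration with hypothetical values:
#
#   d2 = jnp.array([1e4, 1e4 + 2.0])
#   jnp.log(jnp.sum(jnp.exp(-0.5 * d2)))                        # -inf (underflow)
#   scale = jnp.nanmin(d2)
#   jnp.log(jnp.sum(jnp.exp(-0.5 * (d2 - scale)))) - scale / 2  # finite and correct
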
def test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v,
                         m_true, eta_true, dens, vel, omega_m, h, L, xmin,
                         interp_order, bias_epsilon, czCMB, m_obs, eta_obs,
                         sigma_m, sigma_eta, MB_pos, mthresh):
    """
    Plot the likelihood as we scan through the parameters
    [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v]
    to verify that the likelihood shape looks reasonable.

    Args:
        - prior (dict): Upper and lower bounds for a uniform prior for the parameters
        - alpha (float): Exponent for the bias model
        - a_TFR (float): TFR relation intercept
        - b_TFR (float): TFR relation slope
        - sigma_TFR (float): Intrinsic scatter in the TFR
        - sigma_v (float): Uncertainty on the velocity field (km/s)
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,))
        - eta_true (np.ndarray): True linewidths of the tracers (shape = (Nt,))
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of the corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to avoid evaluating 0 ** alpha
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_eta (float): Uncertainty on the linewidth measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use
            in the likelihood (Mpc/h). The shape is (3, Nt, Nint_points).
        - mthresh (float): Apparent magnitude threshold in the selection
    """

    pars = [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v]
    par_names = ['alpha', 'a_TFR', 'b_TFR', 'sigma_TFR', 'sigma_v']

    orig_ll = likelihood(*pars, m_true, eta_true, dens, vel, omega_m, h, L,
                         xmin, interp_order, bias_epsilon, czCMB, m_obs,
                         eta_obs, sigma_m, sigma_eta, MB_pos, mthresh)

    for i, name in enumerate(par_names):
        myprint(f'Scanning {name}')
        if name in prior:
            x = np.linspace(*prior[name], 20)
        else:
            pmin = pars[i] * 0.2
            pmax = pars[i] * 2.0
            x = np.linspace(pmin, pmax, 20)
        all_ll = np.empty(x.shape)
        orig_x = pars[i]
        for j, xx in enumerate(x):
            pars[i] = xx
            all_ll[j] = likelihood(*pars, m_true, eta_true, dens, vel, omega_m,
                                   h, L, xmin, interp_order, bias_epsilon,
                                   czCMB, m_obs, eta_obs, sigma_m, sigma_eta,
                                   MB_pos, mthresh)
        pars[i] = orig_x

        plt.figure()
        plt.plot(x, all_ll, '.')
        plt.axvline(orig_x, ls='--', color='k')
        plt.axhline(orig_ll, ls='--', color='k')
        plt.xlabel(name)
        plt.ylabel('Log-likelihood')
        plt.savefig(f'likelihood_scan_{name}.png')
        fig = plt.gcf()
        plt.clf()
        plt.close(fig)

    return
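
# run_mcmc (below) exposes the custom log-likelihood to numpyro by wrapping it
# in a dist.Distribution subclass with a log_prob method, so it behaves as an
# observed sample site. A minimal alternative sketch, using the numpyro.factor
# primitive to add an arbitrary term to the joint log-density (the sampled
# sites shown are illustrative, not the full model):
#
#   def tfr_model_sketch():
#       alpha = numpyro.sample("alpha", dist.Uniform(0.5, 2.5))
#       sigma_v = numpyro.sample("sigma_v", dist.HalfCauchy(1.0))
#       numpyro.factor("loglike", likelihood(alpha, ..., sigma_v, ...))
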
def run_mcmc(num_warmup, num_samples, prior, initial, dens, vel, omega_m, h,
             L, xmin, interp_order, bias_epsilon, czCMB, m_obs, eta_obs,
             sigma_m, sigma_eta, MB_pos, mthresh, m_true):
    """
    Run an MCMC over the model parameters

    Args:
        - num_warmup (int): Number of warmup steps to take in the MCMC
        - num_samples (int): Number of samples to take in the MCMC
        - prior (dict): Upper and lower bounds of the uniform priors on the sampled parameters
        - initial (dict): Initial parameter values with which to start the chain
        - dens (np.ndarray): Over-density field (shape = (N, N, N))
        - vel (np.ndarray): Velocity field (km/s) (shape = (3, N, N, N))
        - omega_m (float): Matter density parameter Om
        - h (float): Hubble constant H0 = 100 h km/s/Mpc
        - L (float): Comoving box size (Mpc/h)
        - xmin (float): Coordinate of the corner of the box (Mpc/h)
        - interp_order (int): Order of interpolation from grid points to the line of sight
        - bias_epsilon (float): Small number to add to 1 + delta to avoid evaluating 0 ** alpha
        - czCMB (np.ndarray): Observed redshifts (km/s) of the tracers (shape = (Nt,))
        - m_obs (np.ndarray): Observed apparent magnitudes of the tracers (shape = (Nt,))
        - eta_obs (np.ndarray): Observed linewidths of the tracers (shape = (Nt,))
        - sigma_m (float): Uncertainty on the apparent magnitude measurements
        - sigma_eta (float): Uncertainty on the linewidth measurements
        - MB_pos (np.ndarray): Comoving coordinates of integration points to use
            in the likelihood (Mpc/h). The shape is (3, Nt, Nint_points).
        - mthresh (float): Apparent magnitude threshold in the selection
        - m_true (np.ndarray): True apparent magnitudes of the tracers (shape = (Nt,)),
            currently held fixed in the model

    Returns:
        - mcmc (numpyro.infer.MCMC): The fitted MCMC object
    """

    Nt = eta_obs.shape[0]

    def tfr_model():
        alpha = numpyro.sample("alpha", dist.Uniform(*prior['alpha']))
        a_TFR = numpyro.sample("a_TFR", dist.Uniform(*prior['a_TFR']))
        b_TFR = numpyro.sample("b_TFR", dist.Uniform(*prior['b_TFR']))
        sigma_TFR = numpyro.sample("sigma_TFR", dist.HalfCauchy(1.0))
        sigma_v = numpyro.sample("sigma_v", dist.HalfCauchy(1.0))

        # # Sample the means with a uniform prior
        # hyper_mean_m = numpyro.sample("hyper_mean_m", dist.Uniform(*prior['hyper_mean_m']))
        # hyper_mean_eta = numpyro.sample("hyper_mean_eta", dist.Uniform(*prior['hyper_mean_eta']))
        # hyper_mean = jnp.array([hyper_mean_m, hyper_mean_eta])

        # # Sample standard deviations with a weakly informative half-Cauchy prior
        # hyper_sigma_m = numpyro.sample("hyper_sigma_m", dist.HalfCauchy(1.0))
        # hyper_sigma_eta = numpyro.sample("hyper_sigma_eta", dist.HalfCauchy(1.0))
        # hyper_sigma = jnp.array([hyper_sigma_m, hyper_sigma_eta])

        # # Sample correlation matrix using LKJ prior
        # L_corr = numpyro.sample("L_corr", dist.LKJCholesky(2, concentration=1.0))  # Cholesky factor of correlation matrix
        # corr_matrix = L_corr @ L_corr.T  # Convert to full correlation matrix

        # # Construct full covariance matrix: Sigma = D * Corr * D
        # hyper_cov = jnp.diag(hyper_sigma) @ corr_matrix @ jnp.diag(hyper_sigma)

        # # Sample the true eta and m
        # x = numpyro.sample("x", dist.MultivariateNormal(hyper_mean, hyper_cov), sample_shape=(Nt,))
        # m_true = numpyro.deterministic("m_true", x[:, 0])
        # eta_true = numpyro.deterministic("eta_true", x[:, 1])

        hyper_mean_eta = numpyro.sample("hyper_mean_eta", dist.Uniform(*prior['hyper_mean_eta']))
        hyper_sigma_eta = numpyro.sample("hyper_sigma_eta", dist.HalfCauchy(1.0))  # Weakly informative prior for a scale parameter
        eta_true = numpyro.sample("eta_true", dist.Normal(hyper_mean_eta, hyper_sigma_eta), sample_shape=(Nt,))

        # Evaluate the likelihood
        numpyro.sample("obs",
                       TFRLikelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true),
                       obs=jnp.array([m_obs, eta_obs]))

    class TFRLikelihood(dist.Distribution):
        support = dist.constraints.real

        def __init__(self, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true):
            self.alpha, self.a_TFR, self.b_TFR, self.sigma_TFR, self.sigma_v, self.eta_true = dist.util.promote_shapes(
                alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, eta_true)
            batch_shape = lax.broadcast_shapes(
                jnp.shape(alpha),
                jnp.shape(a_TFR),
                jnp.shape(b_TFR),
                jnp.shape(sigma_TFR),
                jnp.shape(sigma_v),
                # jnp.shape(m_true),
                jnp.shape(eta_true),
            )
            super().__init__(batch_shape=batch_shape)

        def sample(self, key, sample_shape=()):
            raise NotImplementedError

        def log_prob(self, value):
            loglike = likelihood(self.alpha, self.a_TFR, self.b_TFR,
                                 self.sigma_TFR, self.sigma_v, m_true,
                                 self.eta_true, dens, vel, omega_m, h, L, xmin,
                                 interp_order, bias_epsilon, czCMB, m_obs,
                                 eta_obs, sigma_m, sigma_eta, MB_pos, mthresh)
            return loglike

    rng_key = random.PRNGKey(6)
    rng_key, rng_key_ = random.split(rng_key)

    myprint('Preparing MCMC kernel')
    kernel = numpyro.infer.NUTS(
        tfr_model,
        init_strategy=numpyro.infer.initialization.init_to_value(values=initial)
    )
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    myprint('Running MCMC')
    mcmc.run(rng_key_)
    mcmc.print_summary()

    return mcmc
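
# Note on the pattern used in process_mcmc_run below: mcmc.get_samples()
# returns a dict {site_name: array of draws}, and since JAX arrays are
# immutable the sample matrix is filled with the functional-update syntax
# samps.at[:, i].set(col), each call returning a new array, rather than with
# in-place assignment.
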
def process_mcmc_run(mcmc, param_labels, truths, obs):
    """
    Make diagnostic plots (trace, corner, true vs predicted) from an MCMC run.

    Args:
        - mcmc (numpyro.infer.MCMC): The fitted MCMC object
        - param_labels (list[str]): Names of the sampled sites to plot
        - truths (list[float]): True values of those parameters in the mock
        - obs (dict): Observed values ('m' and 'eta') against which to compare
            the inferred true values
    """

    # Convert samples into a single array
    samples = mcmc.get_samples()
    samps = jnp.empty((len(samples[param_labels[0]]), len(param_labels)))
    for i, p in enumerate(param_labels):
        samps = samps.at[:, i].set(samples[p])

    # Trace plot of non-redshift quantities
    fig1, axs1 = plt.subplots(samps.shape[1], 1, figsize=(6, 3 * samps.shape[1]), sharex=True)
    axs1 = np.atleast_1d(axs1)
    for i in range(samps.shape[1]):
        axs1[i].plot(samps[:, i])
        axs1[i].set_ylabel(param_labels[i])
        axs1[i].axhline(truths[i], color='k')
    axs1[-1].set_xlabel('Step Number')
    fig1.tight_layout()
    fig1.savefig('trace.png')

    # Corner plot
    fig2, axs2 = plt.subplots(samps.shape[1], samps.shape[1], figsize=(10, 10))
    corner.corner(
        np.array(samps),
        labels=param_labels,
        fig=fig2,
        truths=truths
    )
    fig2.savefig('corner.png')

    # True vs predicted
    for var in ['eta', 'm']:
        vname = var + '_true'
        if vname in samples.keys():
            xtrue = obs[var]
            xpred_median = np.median(samples[vname], axis=0)
            xpred_plus = np.percentile(samples[vname], 84, axis=0) - xpred_median
            xpred_minus = xpred_median - np.percentile(samples[vname], 16, axis=0)

            fig3, axs3 = plt.subplots(2, 1, figsize=(10, 8), sharex=True)
            plot_kwargs = {'fmt': '.', 'markersize': 3, 'zorder': 10,
                           'capsize': 1, 'elinewidth': 1, 'alpha': 1}
            axs3[0].errorbar(xtrue, xpred_median, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].errorbar(xtrue, xpred_median - xtrue, yerr=[xpred_minus, xpred_plus], **plot_kwargs)
            axs3[1].set_xlabel('True')
            axs3[0].set_ylabel('Predicted')
            axs3[1].set_ylabel('Predicted - True')
            xlim = axs3[0].get_xlim()
            ylim = axs3[0].get_ylim()
            axs3[0].plot(xlim, xlim, color='k', zorder=0)
            axs3[0].set_xlim(xlim)
            axs3[0].set_ylim(ylim)
            axs3[1].axhline(0, color='k', zorder=0)
            fig3.suptitle(var)
            fig3.align_labels()
            fig3.tight_layout()
            fig3.savefig(f'true_predicted_{var}.png')

    return


def main():

    myprint('Beginning')

    # Get some parameters from the data
    sigma_m, sigma_eta, hyper_eta_mu, hyper_eta_sigma = estimate_data_parameters()

    # Other parameters to use
    L = 500.0
    N = 64
    xmin = -L / 2
    R_lim = L / 2
    Rmax = 100
    Nt = 100
    alpha = 1.4
    mthresh = 11.25
    a_TFR = -23
    b_TFR = -8.2
    sigma_TFR = 0.3
    sigma_v = 150
    Nint_points = 201
    Nsig = 10
    frac_sigma_r = 0.07  # WANT A BETTER WAY OF DOING THIS - ESTIMATE THROUGH SIGMAS FROM TFR
    interp_order = 1
    bias_epsilon = 1.e-7
    num_warmup = 1000
    num_samples = 1000

    prior = {
        'alpha': [0.5, 2.5],
        'a_TFR': [-25, -20],
        'b_TFR': [-10, -5],
        'hyper_mean_eta': [hyper_eta_mu - 0.5, hyper_eta_mu + 0.5],
        # 'hyper_mean_m': [mthresh - 5, mthresh + 5],
    }
    initial = {
        'alpha': alpha,
        'a_TFR': a_TFR,
        'b_TFR': b_TFR,
        'hyper_mean_eta': hyper_eta_mu,
        'hyper_sigma_eta': hyper_eta_sigma,
        # 'hyper_mean_m': mthresh,
        'sigma_TFR': sigma_TFR,
        'sigma_v': sigma_v,
    }

    # Make mock
    np.random.seed(123)
    cpar, dens, vel = get_fields(L, N, xmin)
    RA, Dec, czCMB, m_true, eta_true, m_obs, eta_obs, xtrue = create_mock(
        Nt, L, xmin, cpar, dens, vel, Rmax, alpha, mthresh,
        a_TFR, b_TFR, sigma_TFR, sigma_m, sigma_eta,
        hyper_eta_mu, hyper_eta_sigma, sigma_v,
        interp_order=interp_order, bias_epsilon=bias_epsilon)
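    # (For the values above, a tracer at robs ~ 50 Mpc/h has
    # sig = sqrt((150 / 100)**2 + (0.07 * 50)**2) ~ 3.8 Mpc/h, so with
    # Nsig = 10 the integration grid spans roughly +/- 38 Mpc/h around the
    # central distance, sampled at Nint_points = 201 radii.)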
    MB_pos = generateMBData(RA, Dec, czCMB, L, N, R_lim, Nsig, Nint_points,
                            sigma_v, frac_sigma_r)

    # Test the likelihood
    loglike = likelihood(alpha, a_TFR, b_TFR, sigma_TFR, sigma_v, m_true,
                         eta_true, dens, vel, cpar.omega_m, cpar.h, L, xmin,
                         interp_order, bias_epsilon, czCMB, m_obs, eta_obs,
                         sigma_m, sigma_eta, MB_pos, mthresh)
    myprint(f'loglike {loglike}')

    # Scan over parameters to make plots verifying behaviour
    test_likelihood_scan(prior, alpha, a_TFR, b_TFR, sigma_TFR, sigma_v,
                         m_true, eta_true, dens, vel, cpar.omega_m, cpar.h, L,
                         xmin, interp_order, bias_epsilon, czCMB, m_obs,
                         eta_obs, sigma_m, sigma_eta, MB_pos, mthresh)

    # Run the MCMC
    mcmc = run_mcmc(num_warmup, num_samples, prior, initial, dens, vel,
                    cpar.omega_m, cpar.h, L, xmin, interp_order, bias_epsilon,
                    czCMB, m_obs, eta_obs, sigma_m, sigma_eta, MB_pos, mthresh,
                    m_true)

    # param_labels = ['alpha', 'a_TFR', 'b_TFR', 'sigma_TFR', 'sigma_v',
    #                 'hyper_mean_eta', 'hyper_sigma_eta']
    # truths = [alpha, a_TFR, b_TFR, sigma_TFR, sigma_v,
    #           hyper_eta_mu, hyper_eta_sigma]
    # For now, only process the linewidth hyper-parameters
    param_labels = ['hyper_mean_eta', 'hyper_sigma_eta']
    truths = [hyper_eta_mu, hyper_eta_sigma]
    obs = {'m': m_obs, 'eta': eta_obs}
    process_mcmc_run(mcmc, param_labels, truths, obs)


if __name__ == "__main__":
    main()

"""
TO DO
- Fails to initialise currently when loglike includes the BORG term
- Run the MCMC with this likelihood
- Add bulk velocity
- Deal with the case where sigma_eta and sigma_m could be floats vs arrays
"""