Add HMC and BlackJax parameter samplers

This commit is contained in:
Deaglan Bartlett 2024-10-25 11:40:58 +02:00
parent fca3d19059
commit 8a0c638e5b
29 changed files with 829 additions and 160 deletions

1
.gitignore vendored
View file

@ -168,3 +168,4 @@ tests/*.h5
*timing_stats*
*fft_wisdom*
*allocation_stats*
scripts/out_files

View file

@ -1 +1,27 @@
# borg_velocity
# borg_velocity
To install in a fresh environment on infinity, follow these instructions:
```
module purge
module restore myborg
module load cuda/12.3
conda create -n borg_new
conda activate borg_new
conda config --add channels conda-forge
conda install c-compiler
conda install zlib
conda install python=3 healpy h5py numexpr numba deprecated
conda install gcc_linux-64 gxx_linux-64 mpi4py
pip install -U "jax[cuda12]=0.4.31”
pip install blackjax
cd /home/bartlett/borg
bash build.sh --c-compiler $(which x86_64-conda_cos6-linux-gnu-gcc) --cxx-compiler $(which x86_64-conda_cos6-linux-gnu-g++) --python=$(which python3) —-install-system-python --hades-python --use-system-hdf5 --build-dir /data101/bartlett/build_borg/
cd /data101/bartlett/build_borg/
make -j 32
make python-install
cd /home/bartlett/fsigma8/borg_velocity
pip3 install -e .
```

View file

@ -1,10 +1,12 @@
import numpy as np
import aquila_borg as borg
import jax.numpy as jnp
import configparser
import warnings
import aquila_borg as borg
import symbolic_pofk.linear
#import symbolic_pofk.linear
import jax
from jax import lax
import jaxlib
from functools import partial
import ast
@ -12,11 +14,11 @@ import numbers
import h5py
import borg_velocity.utils as utils
from borg_velocity.utils import myprint
from borg_velocity.utils import myprint, compute_As, get_sigma_bulk
import borg_velocity.forwards as forwards
import borg_velocity.mock_maker as mock_maker
import borg_velocity.projection as projection
from borg_velocity.samplers import HMCBiasSampler, derive_prior
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler
import borg_velocity.samplers as samplers
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
@ -159,8 +161,9 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
cpar = cosmo
# Convert sigma8 to As
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
cpar = compute_As(cpar)
# cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
# cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
myprint(f"Updating cosmology Om = {cosmo.omega_m}, sig8 = {cosmo.sigma8}, As = {cosmo.A_s}")
cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
@ -260,7 +263,7 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
sig_v = self.model_params['sig_v']
# Compute velocity field
bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
self.bulk_flow = jnp.array([self.model_params['bulk_flow_x'],
self.model_params['bulk_flow_y'],
self.model_params['bulk_flow_z']])
v = output_velocity + self.bulk_flow.reshape((3, 1, 1, 1))
@ -291,8 +294,18 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
self.R_max
)
if not jnp.isfinite(lkl):
lkl = self.bignum
# Add in bulk flow prior
# sigma_bulk = get_sigma_bulk(self.L[0], self.fwd.getCosmoParams())
# lkl += jnp.sum(0.5 * jnp.log(2 * np.pi) + jnp.log(sigma_bulk) + self.bulk_flow ** 2 / 2. / sigma_bulk ** 2)
# if not jnp.isfinite(lkl):
# lkl = self.bignum
lkl = lax.cond(
jnp.isfinite(lkl),
lambda _: lkl, # If True (finite), return lkl
lambda _: self.bignum, # If False (not finite), return self.bignum
operand=None # No additional operand needed here
)
return lkl
@ -441,14 +454,14 @@ def vel2like(cz_obs, v, MB_field, MB_pos, r, r_hMpc, sig_mu, sig_v, omega_m, muA
# Multiply p_r by arbitrary number for numerical stability (cancels in p_r / p_r_norm)
d2 = d2 - jnp.expand_dims(jnp.nanmin(d2, axis=1), axis=1)
p_r = r ** 2 * jnp.exp(-0.5 * d2) * los_density * jnp.exp(- lam * r / R_max)
p_r_norm = jnp.expand_dims(jnp.trapz(p_r, r, axis=1), axis=1)
p_r_norm = jnp.expand_dims(jnp.trapezoid(p_r, r, axis=1), axis=1)
# Integrate to get likelihood
d2 = delta_cz_sigv**2
scale = jnp.nanmin(d2, axis=1)
d2 = d2 - jnp.expand_dims(scale, axis=1)
exp_delta_cz = jnp.exp(-0.5 * d2)
p_cz = jnp.trapz(exp_delta_cz * p_r / p_r_norm, r, axis=1)
p_cz = jnp.trapezoid(exp_delta_cz * p_r / p_r_norm, r, axis=1)
lkl_ind = jnp.log(p_cz) - scale / 2 - 0.5 * jnp.log(2 * np.pi * sig_v**2)
lkl = - lkl_ind.sum()
@ -772,6 +785,27 @@ def build_sampler(
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
elif config['sampling']['algorithm'].lower() == 'mvslice':
model_sampler = MVSliceBiasSampler(
"model_params",
likelihood,
params,
)
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
elif config['sampling']['algorithm'].lower() == 'blackjax':
model_sampler = BlackJaxBiasSampler(
"model_params",
likelihood,
params,
int(config['sampling']['rng_seed']),
int(config['sampling']['warmup_nsteps']),
float(config['sampling']['warmup_target_acceptance_rate']),
)
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
else:
raise NotImplementedError
model_sampler.setName("model_sampler")
@ -781,7 +815,7 @@ def build_sampler(
loop.addToConditionGroup("warmup_model", "model_sampler")
loop.addConditionToConditionGroup("warmup_model", partial(check_model_sampling, loop))
myprint('Done')
return []

View file

@ -3,6 +3,7 @@ import aquila_borg as borg
from typing import List
import jax.numpy as jnp
import jax
import blackjax
from borg_velocity.utils import myprint
@ -65,6 +66,7 @@ class HMCBiasSampler(borg.samplers.PyBaseSampler):
if self.transform_attributes[i]:
xnew = self.transform_attributes[i](xnew)
self.likelihood.model_params[self.attributes[i]] = xnew
def _likelihood(self, x_hat, x, skip_density=False):
self._update_attrs(x)
@ -204,3 +206,182 @@ def inv_transform_uniform(alpha, a, b):
"""
x = (alpha - a) / (b - a)
return - jnp.log(1/x - 1)
class MVSliceBiasSampler(borg.samplers.PyBaseSampler):
def __init__(
self,
prefix: str,
likelihood: borg.likelihood.Likelihood3d,
attributes: List[str],
):
super().__init__()
self.likelihood = likelihood
self.attributes = attributes
self.prefix = prefix
self.lam = 1.0
def initialize(self, state: borg.likelihood.MarkovState):
self.restore(state)
for i in range(len(self.attributes)):
self.y[i] = self.likelihood.fwd_param.getModelParam(
'mod_null',
self.attributes[i]
)
self.mvs_state[:] = 0 # np.random.uniform(size=len(self.attributes))
def _update_attrs(self, x):
for i in range(len(self.attributes)):
self.likelihood.model_params[self.attributes[i]] = x[i]
self.likelihood.fwd_param.setModelParams(self.likelihood.model_params)
def sample(self, state: borg.likelihood.MarkovState):
myprint(f"Sampling attributes {self.attributes}")
x_hat = state["s_hat_field"]
def _callback(x):
# print("Callback: " + repr(x))
self._update_attrs(x)
return -self.likelihood.logLikelihoodComplex(x_hat, False)
self.y[:] = borg.samplers.mv_slice_sampler(
state, _callback, self.y, self.mvs_state, self.lam
)
self._update_attrs(self.y)
def restore(self, state: borg.likelihood.MarkovState):
attrname = f"{self.prefix}attributes"
state.newArray1d(attrname, len(self.attributes), True)
self.y = state[attrname]
def _loaded():
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
self._update_attrs(self.y)
state.subscribeLoaded(attrname, _loaded)
mvsname = f"{self.prefix}mvs_state"
state.newArray1d(mvsname, len(self.attributes), True)
self.mvs_state = state[mvsname]
class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
def __init__(
self,
prefix: str,
likelihood: borg.likelihood.Likelihood3d,
attributes: List[str],
rng_seed: int,
warmup_nsteps: int,
warmup_target_acceptance_rate: float
):
super().__init__()
self.likelihood = likelihood
self.attributes = attributes
self.prefix = prefix
self.rng_key = jax.random.key(rng_seed)
self.warmup_nsteps = warmup_nsteps
self.warmup_target_acceptance_rate = warmup_target_acceptance_rate
# CHANGE THIS
# self.step_size = 1e-4
# self.inverse_mass_matrix = np.ones(len(attributes))
self.parameters = None
print("MY ATTRIBUTES:", attributes)
def initialize(self, state: borg.likelihood.MarkovState):
self.restore(state)
for i in range(len(self.attributes)):
self.y[i] = self.likelihood.fwd_param.getModelParam(
'mod_null',
self.attributes[i]
)
self.mvs_state[:] = 0 # np.random.uniform(size=len(self.attributes))
def _update_attrs(self, x):
for i in range(len(self.attributes)):
self.likelihood.model_params[self.attributes[i]] = x[i]
# self.likelihood.fwd_param.setModelParams(self.likelihood.model_params)
def _likelihood(self, x_hat, x, skip_density=False):
self._update_attrs(x)
return self.likelihood.logLikelihoodComplex(
x_hat, False, skip_density=skip_density, update_from_model=False,
)
def sample(self, state: borg.likelihood.MarkovState):
myprint(f"Sampling attributes {self.attributes}")
x_hat = state["s_hat_field"]
# Run forward model to compute correct density field
_ = self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=False)
@jax.custom_vjp
def logdensity_fn(x):
self._update_attrs(x)
return -self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=True, update_from_model=False)
def logdensity_fn_fwd(x):
y = logdensity_fn(x)
return y, x # Save `x` for the backward pass
def logdensity_fn_bwd(x, cotangent_y):
grad_fn = jax.grad(self._likelihood, argnums=1)
grad_x = grad_fn(x_hat, x, skip_density=True)
cotangent_x = grad_x * cotangent_y
return (cotangent_x,)
logdensity_fn.defvjp(logdensity_fn_fwd, logdensity_fn_bwd)
if self.parameters is None:
get_info = blackjax.adaptation.base.get_filter_adapt_info_fn(
state_keys={'position', 'logdensity'},
info_keys={'acceptance_rate'},
adapt_state_keys={'step_size', 'inverse_mass_matrix'}
)
logdensity_fn_jitted = jax.jit(logdensity_fn)
warmup = blackjax.window_adaptation(
blackjax.nuts,
logdensity_fn_jitted,
target_acceptance_rate=self.warmup_target_acceptance_rate,
progress_bar=True,
adaptation_info_fn=get_info
)
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
(state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
# x = info.state.position
self.y[:] = state.position
nuts = blackjax.nuts(logdensity_fn, **self.parameters)
bias_state = nuts.init(self.y)
nuts_step = nuts.step
self.rng_key, sample_key = jax.random.split(self.rng_key)
bias_state, _ = nuts_step(sample_key, bias_state)
self.y[:] = bias_state.position
self._update_attrs(self.y)
to_set = {k:v for k, v in zip(self.attributes, self.y)}
self.likelihood.fwd_param.setModelParams(to_set)
def restore(self, state: borg.likelihood.MarkovState):
attrname = f"{self.prefix}attributes"
state.newArray1d(attrname, len(self.attributes), True)
self.y = state[attrname]
def _loaded():
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
self._update_attrs(self.y)
state.subscribeLoaded(attrname, _loaded)
mvsname = f"{self.prefix}mvs_state"
state.newArray1d(mvsname, len(self.attributes), True)
self.mvs_state = state[mvsname]

View file

@ -1,15 +1,52 @@
import aquila_borg as borg
import configparser
import os
import symbolic_pofk.linear
#import symbolic_pofk.linear
from functools import partial
import jax
import jax.numpy as jnp
import scipy.integrate
import math
# Output stream management
cons = borg.console()
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
a = jnp.array([1.0])
x = jnp.copy(a)
def compute_As(cpar):
"""
Compute As given values of sigma8
Args:
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters with wrong As
Returns:
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters with updated As
"""
# requires BORG-CLASS
if not hasattr(borg.cosmo, 'ClassCosmo'):
raise ImportError(
"BORG-CLASS is required to compute As, but is not installed.")
sigma8_true = jnp.copy(cpar.sigma8)
cpar.sigma8 = 0
cpar.A_s = 2.3e-9
k_max, k_per_decade = 10, 100
extra_class = {}
extra_class['YHe'] = '0.24'
cosmo = borg.cosmo.ClassCosmo(cpar, k_per_decade, k_max, extra=extra_class)
cosmo.computeSigma8()
cos = cosmo.getCosmology()
cpar.A_s = (sigma8_true/cos['sigma_8'])**2*cpar.A_s
cpar.sigma8 = sigma8_true
return cpar
def get_cosmopar(ini_file):
"""
@ -37,8 +74,10 @@ def get_cosmopar(ini_file):
cpar.n_s = float(config['cosmology']['n_s'])
cpar.w = float(config['cosmology']['w'])
cpar.wprime = float(config['cosmology']['wprime'])
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
# cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
# cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
cpar = compute_As(cpar)
return cpar
@ -90,4 +129,60 @@ def z_cos(r_hMpc: float, Omega_m: float) -> float:
"""
Omega_L = 1. - Omega_m
q0 = Omega_m/2.0 - Omega_L
return (1.0 - jnp.sqrt(1 - 2*r_hMpc*100*(1 + q0)/speed_of_light))/(1.0 + q0)
return (1.0 - jnp.sqrt(1 - 2*r_hMpc*100*(1 + q0)/speed_of_light))/(1.0 + q0)
def get_sigma_bulk(R, cpar):
"""
Compute the bulk flow variance when the field is smoothed with a top
hat filter of side length R.
sigma^2 = H_0^2 f^2 / (2 \pi^2) \int_0^\infty dk W^2(k,R) P(k)
where W(k,R) is the Fourier transform of the top hat filter
Args:
:R (float): Box length
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters with wrong As
Returns:
:sigma (float): Bulk flow RMS [km/s]
"""
k_max, k_per_decade = 10, 100
extra_class = {}
extra_class['YHe'] = '0.24'
cosmo = borg.cosmo.ClassCosmo(cpar, k_per_decade, k_max, extra=extra_class)
# Compute integrand
k = jnp.logspace(-4, 1, 300) # (h / Mpc)
plin_class = cosmo.get_Pk_matter(k) # CHECK: Is this in (Mpc/h)^3?
x = k * R
integrand = plin_class * (3 * (jnp.sin(x) - x * jnp.cos(x)) / x ** 3) ** 2 # (Mpc/h)^3
# Check units
sigma = scipy.integrate.simpson(integrand, x=k) # (Mpc/h)^2
H0 = 100 # h km/s/Mpc
f = cpar.omega_m ** 0.55
sigma = jnp.sqrt(sigma / 2) * H0 * f / math.pi # (Mpc/h) * h km/s/Mpc = km/s
return sigma
if __name__ == "__main__":
cpar = get_cosmopar('../conf/basic_ini.ini')
print(cpar)
cpar = compute_As(cpar)
print(cpar)
print(1.e-9 * symbolic_pofk.linear.sigma8_to_As(
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s))
print(get_sigma_bulk(500, cpar))
"""
- Find where else symbolic As is used
- Add prior for sigma_bulk
"""

View file

@ -18,9 +18,9 @@ seed_cpower = true
hades_sampler_blocked = true
bias_sampler_blocked= true
nmean_sampler_blocked= true
sigma8_sampler_blocked = false
sigma8_sampler_blocked = true
omega_m_sampler_blocked = true
muA_sampler_blocked = false
omega_m_sampler_blocked = false
alpha_sampler_blocked = false
lam_sampler_blocked = false
sig_v_sampler_blocked = false
@ -42,10 +42,13 @@ max_timesteps = 50
mixing = 1
[sampling]
algorithm = HMC
epsilon = 0.001
algorithm = blackjax
epsilon = 0.005
Nsteps = 20
refresh = 0.1
rng_seed = 1
warmup_nsteps = 100
warmup_target_acceptance_rate = 0.7
[model]
gravity = lpt

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 15 KiB

BIN
figs/bulk_test.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

After

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 169 KiB

After

Width:  |  Height:  |  Size: 170 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 98 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 19 KiB

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 50 KiB

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 84 KiB

After

Width:  |  Height:  |  Size: 97 KiB

BIN
figs/spectra_delta.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

BIN
figs/spectra_ics.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 180 KiB

After

Width:  |  Height:  |  Size: 305 KiB

File diff suppressed because one or more lines are too long

221
notebooks/Bulk_flow.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -9,7 +9,10 @@ module load cuda/12.3
source /home/bartlett/.bashrc
source /home/bartlett/anaconda3/etc/profile.d/conda.sh
conda deactivate
conda activate borg_env
#conda activate borg_env
#conda activate borg
#conda activate compiler_test
conda activate borg_new
# Kill job if there are any errors
set -e
@ -21,9 +24,10 @@ BUILD_DIR=/data101/bartlett/build_borg/
cd $BORG_DIR
# git pull
# rm -rf $BUILD_DIR
bash build.sh --c-compiler $(which x86_64-conda_cos6-linux-gnu-gcc) --cxx-compiler $(which x86_64-conda_cos6-linux-gnu-g++) --python --hades-python --build-dir $BUILD_DIR
#bash build.sh --c-compiler $(which x86_64-conda_cos6-linux-gnu-gcc) --cxx-compiler $(which x86_64-conda_cos6-linux-gnu-g++) --python --hades-python --build-dir $BUILD_DIR
bash build.sh --c-compiler $(which x86_64-conda_cos6-linux-gnu-gcc) --cxx-compiler $(which x86_64-conda_cos6-linux-gnu-g++) --python=$(which python3) --install-system-python --hades-python --use-system-hdf5 --build-dir /data101/bartlett/build_borg/
cd $BUILD_DIR
make -j
#make -j 32
exit 0
exit 0

View file

@ -3,20 +3,22 @@
# Modules
module purge
module restore myborg
module load cuda/12.3
module load cuda/12.6
# Environment
source /home/bartlett/.bashrc
source /home/bartlett/anaconda3/etc/profile.d/conda.sh
conda deactivate
conda activate borg_env
#conda activate borg_env
#conda activate compiler_test
conda activate borg_new
# Kill job if there are any errors
set -e
# Path variables
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/supranta_pars_N64
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_blackjax
mkdir -p $RUN_DIR
cd $RUN_DIR

View file

@ -1,9 +1,9 @@
#!/bin/sh
#PBS -S /bin/sh
#PBS -N supranta_pars_N64
#PBS -N supranta_N64_hmc_modelpar
#PBS -j oe
#PBS -m ae
#PBS -l nodes=h11:has1gpu:ppn=40,walltime=72:00:00
#PBS -l nodes=h10:has1gpu:ppn=40,walltime=5:00:00
# Modules
module purge
@ -21,7 +21,7 @@ set -e
# Path variables
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/supranta_pars_N64
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_hmc_modelpar
mkdir -p $RUN_DIR
cd $RUN_DIR

View file

@ -14,6 +14,7 @@ test_omegam = True
test_alpha = True
test_lam = True
test_muA = True
test_bulk = True
# Input box
box_in = borg.forward.BoxModel()
@ -181,4 +182,36 @@ if test_muA:
fig.savefig('../figs/muA_test.png')
fig.clf()
plt.close(fig)
# Test bulk flow
if test_bulk:
all_vx = np.linspace(-100, 100, 50)
all_lkl = np.empty(all_vx.shape)
for i, vx in enumerate(all_vx):
mylike.fwd_param.setModelParams({'bulk_flow_x':vx})
all_lkl[i] = mylike.logLikelihoodComplex(s_hat, None)
mylike.fwd_param.setModelParams({'bulk_flow_x':mylike.bulk_flow[0]})
# fid_lkl = mylike.logLikelihoodComplex(s_hat, None)
fid_lkl = np.amin(all_lkl)
all_lkl -= fid_lkl
all_lkl = np.exp(-all_lkl)
sigma_bulk = utils.get_sigma_bulk(mylike.L[0], mylike.fwd.getCosmoParams())
ref_lkl = np.exp(-all_vx ** 2 / 2 / sigma_bulk ** 2)
ref_lkl *= np.amax(all_lkl) / np.amax(ref_lkl)
fig, ax = plt.subplots(1, 1, figsize=(5,5))
ax.plot(all_vx, all_lkl, label='Posterior')
ax.plot(all_vx, ref_lkl, ls='--', label=r'$\Lambda$CDM prior')
ax.axhline(y=0, color='k')
ax.axvline(x=mylike.muA[0], color='k')
ax.legend()
ax.set_xlabel(r'$v_{{\rm bulk}, x}$')
ax.set_ylabel(r'$\mathcal{L}$')
fig.tight_layout()
fig.savefig('../figs/bulk_test.png')
fig.clf()
plt.close(fig)