Enable restarts in BlackJax samplers and with reloading mock data
This commit is contained in:
parent
8a0c638e5b
commit
e16b14e867
8 changed files with 307 additions and 52 deletions
|
@ -12,13 +12,14 @@ from functools import partial
|
|||
import ast
|
||||
import numbers
|
||||
import h5py
|
||||
import re
|
||||
|
||||
import borg_velocity.utils as utils
|
||||
from borg_velocity.utils import myprint, compute_As, get_sigma_bulk
|
||||
import borg_velocity.forwards as forwards
|
||||
import borg_velocity.mock_maker as mock_maker
|
||||
import borg_velocity.projection as projection
|
||||
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler
|
||||
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler, TransformedBlackJaxBiasSampler
|
||||
import borg_velocity.samplers as samplers
|
||||
|
||||
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
||||
|
@ -112,6 +113,12 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
|||
# Initialise derivative
|
||||
self.grad_like = jax.grad(self.dens2like, argnums=(0,1))
|
||||
|
||||
# Find out what kind of run this is
|
||||
if borg.EMBEDDED:
|
||||
self.action = utils.get_action()
|
||||
else:
|
||||
self.action = None
|
||||
|
||||
|
||||
def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
|
||||
"""
|
||||
|
@ -135,6 +142,17 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
|||
# self.data = [state[f"tracer_vr_{i}"] for i in range(self.nsamp)]
|
||||
state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)
|
||||
|
||||
if self.run_type == 'data':
|
||||
self.loadObservedData(make_plot=False)
|
||||
else:
|
||||
if self.action == 'INIT':
|
||||
pass # Data will be loaded later
|
||||
elif self.action == 'RESUME':
|
||||
self.loadMockData(state) # load data from mock_data.h5
|
||||
else:
|
||||
myprint(f'Unknown action: {self.action}')
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
|
||||
"""
|
||||
|
@ -246,6 +264,30 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
|
|||
myprint('From mock')
|
||||
self.saved_s_hat = s_hat.copy()
|
||||
self.logLikelihoodComplex(s_hat, False)
|
||||
|
||||
|
||||
def loadMockData(self, state: borg.likelihood.MarkovState) -> None:
|
||||
|
||||
myprint('Loading previously made mock data')
|
||||
|
||||
self.coord_true = [None] * self.nsamp
|
||||
self.coord_meas = [None] * self.nsamp
|
||||
self.sig_mu = [None] * self.nsamp
|
||||
self.vr_true = [None] * self.nsamp
|
||||
self.cz_obs = [None] * self.nsamp
|
||||
self.r_hMpc = [None] * self.nsamp
|
||||
|
||||
with h5py.File(f'tracer_data_{self.run_type}.h5', 'r') as f:
|
||||
for i in range(self.nsamp):
|
||||
self.coord_true[i] = jnp.array(f[f'sample_{i}/coord_true'][:])
|
||||
self.coord_meas[i] = jnp.array(f[f'sample_{i}/coord_meas'][:])
|
||||
self.sig_mu[i] = f[f'sample_{i}/sig_mu'][()]
|
||||
self.vr_true[i] = jnp.array(f[f'sample_{i}/vr_true'][:])
|
||||
self.cz_obs[i] = jnp.array(f[f'sample_{i}/cz_obs'][:])
|
||||
self.r_hMpc[i] = np.sqrt(np.sum(self.coord_meas[i] ** 2, axis=0))
|
||||
|
||||
self.generateMBData()
|
||||
|
||||
|
||||
def dens2like(self, output_density: np.ndarray, output_velocity: np.ndarray):
|
||||
"""
|
||||
|
@ -786,15 +828,28 @@ def build_sampler(
|
|||
for item in params:
|
||||
file.write(f"{item}\n")
|
||||
elif config['sampling']['algorithm'].lower() == 'mvslice':
|
||||
lam = [None] * len(params)
|
||||
for i, p in enumerate(params):
|
||||
if p.startswith('bulk_flow'):
|
||||
lam[i] = float(config['sampling'][f'mvlam_bulk_flow'])
|
||||
elif p[-1].isdigit():
|
||||
# Search for a non-digit before end of reversed string
|
||||
match = re.search(r'\D(?=\d*$)', p[::-1])
|
||||
lam[i] = float(config['sampling'][f'mvlam_{p[:match.start()]}'])
|
||||
else:
|
||||
lam[i] = float(config['sampling'][f'mvlam_{p}'])
|
||||
model_sampler = MVSliceBiasSampler(
|
||||
"model_params",
|
||||
likelihood,
|
||||
params,
|
||||
lam
|
||||
)
|
||||
with open('model_params.txt', 'w') as file:
|
||||
for item in params:
|
||||
file.write(f"{item}\n")
|
||||
elif config['sampling']['algorithm'].lower() == 'blackjax':
|
||||
# state.newArray1d("inverse_mass_matrix", len(params), True)
|
||||
# state.newScalar('step_size', 0., True)
|
||||
model_sampler = BlackJaxBiasSampler(
|
||||
"model_params",
|
||||
likelihood,
|
||||
|
@ -806,6 +861,21 @@ def build_sampler(
|
|||
with open('model_params.txt', 'w') as file:
|
||||
for item in params:
|
||||
file.write(f"{item}\n")
|
||||
elif config['sampling']['algorithm'].lower() == 'transformedblackjax':
|
||||
model_sampler = TransformedBlackJaxBiasSampler(
|
||||
"model_params",
|
||||
likelihood,
|
||||
params,
|
||||
transform_attributes=transform_attributes,
|
||||
inv_transform_attributes=inv_transform_attributes,
|
||||
prior = [p.prior for p in transform_attributes],
|
||||
rng_seed = int(config['sampling']['rng_seed']),
|
||||
warmup_nsteps = int(config['sampling']['warmup_nsteps']),
|
||||
warmup_target_acceptance_rate = float(config['sampling']['warmup_target_acceptance_rate']),
|
||||
)
|
||||
with open('model_params.txt', 'w') as file:
|
||||
for item in params:
|
||||
file.write(f"{item}\n")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
model_sampler.setName("model_sampler")
|
||||
|
|
|
@ -214,12 +214,13 @@ class MVSliceBiasSampler(borg.samplers.PyBaseSampler):
|
|||
prefix: str,
|
||||
likelihood: borg.likelihood.Likelihood3d,
|
||||
attributes: List[str],
|
||||
lam: List[float],
|
||||
):
|
||||
super().__init__()
|
||||
self.likelihood = likelihood
|
||||
self.attributes = attributes
|
||||
self.prefix = prefix
|
||||
self.lam = 1.0
|
||||
self.lam = np.array(lam)
|
||||
|
||||
def initialize(self, state: borg.likelihood.MarkovState):
|
||||
self.restore(state)
|
||||
|
@ -240,7 +241,6 @@ class MVSliceBiasSampler(borg.samplers.PyBaseSampler):
|
|||
x_hat = state["s_hat_field"]
|
||||
|
||||
def _callback(x):
|
||||
# print("Callback: " + repr(x))
|
||||
self._update_attrs(x)
|
||||
return -self.likelihood.logLikelihoodComplex(x_hat, False)
|
||||
|
||||
|
@ -282,15 +282,10 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
self.rng_key = jax.random.key(rng_seed)
|
||||
self.warmup_nsteps = warmup_nsteps
|
||||
self.warmup_target_acceptance_rate = warmup_target_acceptance_rate
|
||||
|
||||
# CHANGE THIS
|
||||
# self.step_size = 1e-4
|
||||
# self.inverse_mass_matrix = np.ones(len(attributes))
|
||||
self.parameters = None
|
||||
|
||||
print("MY ATTRIBUTES:", attributes)
|
||||
|
||||
def initialize(self, state: borg.likelihood.MarkovState):
|
||||
print(dir(state))
|
||||
self.restore(state)
|
||||
for i in range(len(self.attributes)):
|
||||
self.y[i] = self.likelihood.fwd_param.getModelParam(
|
||||
|
@ -313,6 +308,176 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
def sample(self, state: borg.likelihood.MarkovState):
|
||||
myprint(f"Sampling attributes {self.attributes}")
|
||||
|
||||
# Update the sampling hyperparameters if possible
|
||||
if (self.parameters is None) and (state["step_size"] > 0) and all(state["inverse_mass_matrix"] > 0):
|
||||
myprint('Using loaded hyperparameters for model parameter sampling')
|
||||
myprint(f'Step size: {state["step_size"]}')
|
||||
myprint(f'Inverse mass matrix: {state["inverse_mass_matrix"]}')
|
||||
self.parameters = {"step_size": state["step_size"], "inverse_mass_matrix":state["inverse_mass_matrix"]}
|
||||
|
||||
x_hat = state["s_hat_field"]
|
||||
|
||||
# Run forward model to compute correct density field
|
||||
_ = self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=False)
|
||||
|
||||
@jax.custom_vjp
|
||||
def logdensity_fn(x):
|
||||
return - self._likelihood(x_hat, x, skip_density=True)
|
||||
|
||||
def logdensity_fn_fwd(x):
|
||||
y = logdensity_fn(x)
|
||||
return y, x # Save `x` for the backward pass
|
||||
|
||||
def logdensity_fn_bwd(x, cotangent_y):
|
||||
grad_fn = jax.grad(self._likelihood, argnums=1)
|
||||
grad_x = - grad_fn(x_hat, x, skip_density=True)
|
||||
cotangent_x = grad_x * cotangent_y
|
||||
return (cotangent_x,)
|
||||
|
||||
logdensity_fn.defvjp(logdensity_fn_fwd, logdensity_fn_bwd)
|
||||
|
||||
if self.parameters is None:
|
||||
get_info = blackjax.adaptation.base.get_filter_adapt_info_fn(
|
||||
state_keys={'position', 'logdensity'},
|
||||
info_keys={'acceptance_rate'},
|
||||
adapt_state_keys={'step_size', 'inverse_mass_matrix'}
|
||||
)
|
||||
logdensity_fn_jitted = jax.jit(logdensity_fn)
|
||||
warmup = blackjax.window_adaptation(
|
||||
blackjax.nuts,
|
||||
logdensity_fn_jitted,
|
||||
target_acceptance_rate=self.warmup_target_acceptance_rate,
|
||||
progress_bar=True,
|
||||
adaptation_info_fn=get_info
|
||||
)
|
||||
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
|
||||
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
||||
# x = info.state.position
|
||||
self.y[:] = bj_state.position
|
||||
|
||||
# Save results to the borg state object
|
||||
# state.newArray1d("inverse_mass_matrix", len(self.parameters['inverse_mass_matrix']), True)
|
||||
state["inverse_mass_matrix"][:] = self.parameters['inverse_mass_matrix']
|
||||
state["step_size"] = float(self.parameters['step_size'])
|
||||
|
||||
nuts = blackjax.nuts(logdensity_fn, **self.parameters)
|
||||
bias_state = nuts.init(self.y)
|
||||
nuts_step = nuts.step
|
||||
|
||||
self.rng_key, sample_key = jax.random.split(self.rng_key)
|
||||
bias_state, _ = nuts_step(sample_key, bias_state)
|
||||
|
||||
self.y[:] = bias_state.position
|
||||
self._update_attrs(self.y)
|
||||
|
||||
to_set = {k:v for k, v in zip(self.attributes, self.y)}
|
||||
self.likelihood.fwd_param.setModelParams(to_set)
|
||||
|
||||
|
||||
def restore(self, state: borg.likelihood.MarkovState):
|
||||
|
||||
state.newArray1d("inverse_mass_matrix", len(self.attributes), True)
|
||||
state.newScalar('step_size', 0., True)
|
||||
|
||||
attrname = f"{self.prefix}attributes"
|
||||
state.newArray1d(attrname, len(self.attributes), True)
|
||||
self.y = state[attrname]
|
||||
|
||||
def _loaded():
|
||||
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
|
||||
self._update_attrs(self.y)
|
||||
|
||||
state.subscribeLoaded(attrname, _loaded)
|
||||
|
||||
mvsname = f"{self.prefix}mvs_state"
|
||||
state.newArray1d(mvsname, len(self.attributes), True)
|
||||
self.mvs_state = state[mvsname]
|
||||
|
||||
|
||||
class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
likelihood: borg.likelihood.BaseLikelihood,
|
||||
attributes: List[str],
|
||||
transform_attributes: List[object],
|
||||
inv_transform_attributes: List[object],
|
||||
prior: List[object],
|
||||
rng_seed: int,
|
||||
warmup_nsteps: int,
|
||||
warmup_target_acceptance_rate: float
|
||||
):
|
||||
super().__init__()
|
||||
self.likelihood = likelihood
|
||||
self.attributes = attributes
|
||||
self.prefix = prefix
|
||||
self.transform_attributes = transform_attributes
|
||||
self.inv_transform_attributes = inv_transform_attributes
|
||||
self.prior_attributes = prior
|
||||
self.rng_key = jax.random.key(rng_seed)
|
||||
self.warmup_nsteps = warmup_nsteps
|
||||
self.warmup_target_acceptance_rate = warmup_target_acceptance_rate
|
||||
self.parameters = None
|
||||
|
||||
|
||||
def initialize(self, state: borg.likelihood.MarkovState):
|
||||
self.restore(state)
|
||||
for i in range(len(self.attributes)):
|
||||
attr = self.likelihood.fwd_param.getModelParam(
|
||||
'mod_null',
|
||||
self.attributes[i]
|
||||
)
|
||||
self.y_real[i] = attr
|
||||
if isinstance(attr, float):
|
||||
attr = jnp.array(attr)
|
||||
if self.inv_transform_attributes[i]:
|
||||
self.y[i] = self.inv_transform_attributes[i](attr).item()
|
||||
else:
|
||||
self.y[i] = attr.item()
|
||||
|
||||
myprint(f"initial {self.y=}")
|
||||
|
||||
def _update_attrs(self, x):
|
||||
for i in range(len(self.attributes)):
|
||||
xnew = x[i]
|
||||
if self.transform_attributes[i]:
|
||||
xnew = self.transform_attributes[i](xnew)
|
||||
self.likelihood.model_params[self.attributes[i]] = xnew
|
||||
|
||||
|
||||
def _likelihood(self, x_hat, x, skip_density=False):
|
||||
self._update_attrs(x)
|
||||
return self.likelihood.logLikelihoodComplex(
|
||||
x_hat, False, skip_density=skip_density, update_from_model=False,
|
||||
)
|
||||
|
||||
def _grad_likelihood(self, x_hat, x, skip_density=False):
|
||||
grad_fn = jax.grad(self._likelihood, argnums=1)
|
||||
grad_x = grad_fn(x_hat, x, skip_density=skip_density)
|
||||
return grad_x
|
||||
|
||||
def _prior(self, x):
|
||||
prior_value = 0.0
|
||||
for i in range(len(self.attributes)):
|
||||
if self.prior_attributes[i]:
|
||||
prior_value += self.prior_attributes[i](x[i])
|
||||
return prior_value
|
||||
|
||||
def _grad_prior(self, x):
|
||||
grad_x = jax.grad(self._prior, argnums=0)(x)
|
||||
return grad_x
|
||||
|
||||
|
||||
def sample(self, state: borg.likelihood.MarkovState):
|
||||
myprint(f"Sampling attributes {self.attributes}")
|
||||
|
||||
# Update the sampling hyperparameters if possible
|
||||
if (self.parameters is None) and (state["step_size"] > 0) and all(state["inverse_mass_matrix"] > 0):
|
||||
myprint('Using loaded hyperparameters for model parameter sampling')
|
||||
myprint(f'Step size: {state["step_size"]}')
|
||||
myprint(f'Inverse mass matrix: {state["inverse_mass_matrix"]}')
|
||||
self.parameters = {"step_size": state["step_size"], "inverse_mass_matrix":state["inverse_mass_matrix"]}
|
||||
|
||||
x_hat = state["s_hat_field"]
|
||||
|
||||
# Run forward model to compute correct density field
|
||||
|
@ -321,15 +486,15 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
@jax.custom_vjp
|
||||
def logdensity_fn(x):
|
||||
self._update_attrs(x)
|
||||
return -self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=True, update_from_model=False)
|
||||
p = self._likelihood(x_hat, x, skip_density=True) + self._prior(x)
|
||||
return - p
|
||||
|
||||
def logdensity_fn_fwd(x):
|
||||
y = logdensity_fn(x)
|
||||
return y, x # Save `x` for the backward pass
|
||||
|
||||
def logdensity_fn_bwd(x, cotangent_y):
|
||||
grad_fn = jax.grad(self._likelihood, argnums=1)
|
||||
grad_x = grad_fn(x_hat, x, skip_density=True)
|
||||
grad_x = - (self._grad_likelihood(x_hat, x, skip_density=True) + self._grad_prior(x))
|
||||
cotangent_x = grad_x * cotangent_y
|
||||
return (cotangent_x,)
|
||||
|
||||
|
@ -363,25 +528,37 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
|
||||
self.y[:] = bias_state.position
|
||||
self._update_attrs(self.y)
|
||||
|
||||
for i in range(len(self.attributes)):
|
||||
self.y_real[i] = self.likelihood.model_params[self.attributes[i]]
|
||||
|
||||
to_set = {k:v for k, v in zip(self.attributes, self.y)}
|
||||
to_set = {k:v for k, v in zip(self.attributes, self.y_real)}
|
||||
self.likelihood.fwd_param.setModelParams(to_set)
|
||||
|
||||
|
||||
def restore(self, state: borg.likelihood.MarkovState):
|
||||
attrname = f"{self.prefix}attributes"
|
||||
state.newArray1d(attrname, len(self.attributes), True)
|
||||
self.y = state[attrname]
|
||||
|
||||
state.newArray1d("inverse_mass_matrix", len(self.attributes), True)
|
||||
state.newScalar('step_size', 0., True)
|
||||
|
||||
# Define attribute names
|
||||
attrname_real = self.prefix
|
||||
attrname_var = f"{self.prefix}_hmc_var"
|
||||
|
||||
# Initialize attributes in the state
|
||||
state.newArray1d(attrname_real, len(self.attributes), True)
|
||||
self.y_real = state[attrname_real]
|
||||
|
||||
state.newArray1d(attrname_var, len(self.attributes), True)
|
||||
self.y = state[attrname_var]
|
||||
|
||||
def _loaded():
|
||||
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
|
||||
self._update_attrs(self.y)
|
||||
# Ensure `self.y` is a JAX array
|
||||
self._update_attrs(jnp.array(self.y))
|
||||
|
||||
state.subscribeLoaded(attrname, _loaded)
|
||||
|
||||
mvsname = f"{self.prefix}mvs_state"
|
||||
state.newArray1d(mvsname, len(self.attributes), True)
|
||||
self.mvs_state = state[mvsname]
|
||||
# Subscribe to updates
|
||||
state.subscribeLoaded(attrname_var, _loaded)
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
import scipy.integrate
|
||||
import math
|
||||
import re
|
||||
|
||||
# Output stream management
|
||||
cons = borg.console()
|
||||
|
@ -102,12 +103,10 @@ def get_action():
|
|||
# Check that the this is the line was want and find action
|
||||
cmd = "hades_python"
|
||||
assert cmd in last_line
|
||||
idx = last_line.index(cmd)
|
||||
last_line = last_line[idx+len(cmd):]
|
||||
while last_line[0] == ' ':
|
||||
last_line = last_line[1:]
|
||||
idx = last_line.index(' ')
|
||||
last_line = last_line[:idx].upper()
|
||||
# Pattern matches the command followed by a word (non-space characters), captures the word after the last occurrence
|
||||
pattern = rf'(?:{re.escape(cmd)})(?=\s)(?:\s+(\S+))'
|
||||
matches = re.findall(pattern, last_line)
|
||||
last_line = matches[-1]
|
||||
myprint(f'Running BORG mode: {last_line}')
|
||||
|
||||
return last_line
|
||||
|
|
|
@ -20,15 +20,15 @@ bias_sampler_blocked= true
|
|||
nmean_sampler_blocked= true
|
||||
sigma8_sampler_blocked = true
|
||||
omega_m_sampler_blocked = true
|
||||
muA_sampler_blocked = false
|
||||
alpha_sampler_blocked = false
|
||||
lam_sampler_blocked = false
|
||||
sig_v_sampler_blocked = false
|
||||
muA_sampler_blocked = true
|
||||
alpha_sampler_blocked = true
|
||||
lam_sampler_blocked = true
|
||||
sig_v_sampler_blocked = true
|
||||
bulk_flow_sampler_blocked = false
|
||||
ares_heat = 1.0
|
||||
|
||||
[mcmc]
|
||||
number_to_generate = 10000
|
||||
number_to_generate = 3
|
||||
warmup_model = 0
|
||||
warmup_cosmo = 0
|
||||
random_ic = false
|
||||
|
@ -42,13 +42,18 @@ max_timesteps = 50
|
|||
mixing = 1
|
||||
|
||||
[sampling]
|
||||
algorithm = blackjax
|
||||
algorithm = BlackJax
|
||||
epsilon = 0.005
|
||||
Nsteps = 20
|
||||
refresh = 0.1
|
||||
rng_seed = 1
|
||||
warmup_nsteps = 100
|
||||
warmup_target_acceptance_rate = 0.7
|
||||
mvlam_mua = 0.01
|
||||
mvlam_alpha = 0.5
|
||||
mvlam_lam = 1.0
|
||||
mvlam_bulk_flow = 20
|
||||
mvlam_sig_v = 10
|
||||
|
||||
[model]
|
||||
gravity = lpt
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 98 KiB After Width: | Height: | Size: 2.9 MiB |
BIN
figs/trace.png
BIN
figs/trace.png
Binary file not shown.
Before Width: | Height: | Size: 305 KiB After Width: | Height: | Size: 365 KiB |
File diff suppressed because one or more lines are too long
|
@ -18,7 +18,7 @@ set -e
|
|||
|
||||
# Path variables
|
||||
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_blackjax
|
||||
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_blackjax_restart
|
||||
|
||||
mkdir -p $RUN_DIR
|
||||
cd $RUN_DIR
|
||||
|
@ -33,4 +33,5 @@ set -x
|
|||
# Just ICs
|
||||
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
||||
cp $INI_FILE ini_file.ini
|
||||
$BORG INIT ini_file.ini
|
||||
# $BORG INIT ini_file.ini
|
||||
$BORG RESUME ini_file.ini
|
||||
|
|
Loading…
Add table
Reference in a new issue