Enable restarts in BlackJax samplers and with reloading mock data

This commit is contained in:
Deaglan Bartlett 2024-10-28 16:49:26 +01:00
parent 8a0c638e5b
commit e16b14e867
8 changed files with 307 additions and 52 deletions

View file

@ -12,13 +12,14 @@ from functools import partial
import ast
import numbers
import h5py
import re
import borg_velocity.utils as utils
from borg_velocity.utils import myprint, compute_As, get_sigma_bulk
import borg_velocity.forwards as forwards
import borg_velocity.mock_maker as mock_maker
import borg_velocity.projection as projection
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler
from borg_velocity.samplers import HMCBiasSampler, derive_prior, MVSliceBiasSampler, BlackJaxBiasSampler, TransformedBlackJaxBiasSampler
import borg_velocity.samplers as samplers
class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
@ -112,6 +113,12 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
# Initialise derivative
self.grad_like = jax.grad(self.dens2like, argnums=(0,1))
# Find out what kind of run this is
if borg.EMBEDDED:
self.action = utils.get_action()
else:
self.action = None
def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
"""
@ -135,6 +142,17 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
# self.data = [state[f"tracer_vr_{i}"] for i in range(self.nsamp)]
state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)
if self.run_type == 'data':
self.loadObservedData(make_plot=False)
else:
if self.action == 'INIT':
pass # Data will be loaded later
elif self.action == 'RESUME':
self.loadMockData(state) # load data from mock_data.h5
else:
myprint(f'Unknown action: {self.action}')
raise NotImplementedError
def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
"""
@ -246,6 +264,30 @@ class VelocityBORGLikelihood(borg.likelihood.BaseLikelihood):
myprint('From mock')
self.saved_s_hat = s_hat.copy()
self.logLikelihoodComplex(s_hat, False)
def loadMockData(self, state: borg.likelihood.MarkovState) -> None:
myprint('Loading previously made mock data')
self.coord_true = [None] * self.nsamp
self.coord_meas = [None] * self.nsamp
self.sig_mu = [None] * self.nsamp
self.vr_true = [None] * self.nsamp
self.cz_obs = [None] * self.nsamp
self.r_hMpc = [None] * self.nsamp
with h5py.File(f'tracer_data_{self.run_type}.h5', 'r') as f:
for i in range(self.nsamp):
self.coord_true[i] = jnp.array(f[f'sample_{i}/coord_true'][:])
self.coord_meas[i] = jnp.array(f[f'sample_{i}/coord_meas'][:])
self.sig_mu[i] = f[f'sample_{i}/sig_mu'][()]
self.vr_true[i] = jnp.array(f[f'sample_{i}/vr_true'][:])
self.cz_obs[i] = jnp.array(f[f'sample_{i}/cz_obs'][:])
self.r_hMpc[i] = np.sqrt(np.sum(self.coord_meas[i] ** 2, axis=0))
self.generateMBData()
def dens2like(self, output_density: np.ndarray, output_velocity: np.ndarray):
"""
@ -786,15 +828,28 @@ def build_sampler(
for item in params:
file.write(f"{item}\n")
elif config['sampling']['algorithm'].lower() == 'mvslice':
lam = [None] * len(params)
for i, p in enumerate(params):
if p.startswith('bulk_flow'):
lam[i] = float(config['sampling'][f'mvlam_bulk_flow'])
elif p[-1].isdigit():
# Search for a non-digit before end of reversed string
match = re.search(r'\D(?=\d*$)', p[::-1])
lam[i] = float(config['sampling'][f'mvlam_{p[:match.start()]}'])
else:
lam[i] = float(config['sampling'][f'mvlam_{p}'])
model_sampler = MVSliceBiasSampler(
"model_params",
likelihood,
params,
lam
)
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
elif config['sampling']['algorithm'].lower() == 'blackjax':
# state.newArray1d("inverse_mass_matrix", len(params), True)
# state.newScalar('step_size', 0., True)
model_sampler = BlackJaxBiasSampler(
"model_params",
likelihood,
@ -806,6 +861,21 @@ def build_sampler(
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
elif config['sampling']['algorithm'].lower() == 'transformedblackjax':
model_sampler = TransformedBlackJaxBiasSampler(
"model_params",
likelihood,
params,
transform_attributes=transform_attributes,
inv_transform_attributes=inv_transform_attributes,
prior = [p.prior for p in transform_attributes],
rng_seed = int(config['sampling']['rng_seed']),
warmup_nsteps = int(config['sampling']['warmup_nsteps']),
warmup_target_acceptance_rate = float(config['sampling']['warmup_target_acceptance_rate']),
)
with open('model_params.txt', 'w') as file:
for item in params:
file.write(f"{item}\n")
else:
raise NotImplementedError
model_sampler.setName("model_sampler")

View file

@ -214,12 +214,13 @@ class MVSliceBiasSampler(borg.samplers.PyBaseSampler):
prefix: str,
likelihood: borg.likelihood.Likelihood3d,
attributes: List[str],
lam: List[float],
):
super().__init__()
self.likelihood = likelihood
self.attributes = attributes
self.prefix = prefix
self.lam = 1.0
self.lam = np.array(lam)
def initialize(self, state: borg.likelihood.MarkovState):
self.restore(state)
@ -240,7 +241,6 @@ class MVSliceBiasSampler(borg.samplers.PyBaseSampler):
x_hat = state["s_hat_field"]
def _callback(x):
# print("Callback: " + repr(x))
self._update_attrs(x)
return -self.likelihood.logLikelihoodComplex(x_hat, False)
@ -282,15 +282,10 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
self.rng_key = jax.random.key(rng_seed)
self.warmup_nsteps = warmup_nsteps
self.warmup_target_acceptance_rate = warmup_target_acceptance_rate
# CHANGE THIS
# self.step_size = 1e-4
# self.inverse_mass_matrix = np.ones(len(attributes))
self.parameters = None
print("MY ATTRIBUTES:", attributes)
def initialize(self, state: borg.likelihood.MarkovState):
print(dir(state))
self.restore(state)
for i in range(len(self.attributes)):
self.y[i] = self.likelihood.fwd_param.getModelParam(
@ -313,6 +308,176 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
def sample(self, state: borg.likelihood.MarkovState):
myprint(f"Sampling attributes {self.attributes}")
# Update the sampling hyperparameters if possible
if (self.parameters is None) and (state["step_size"] > 0) and all(state["inverse_mass_matrix"] > 0):
myprint('Using loaded hyperparameters for model parameter sampling')
myprint(f'Step size: {state["step_size"]}')
myprint(f'Inverse mass matrix: {state["inverse_mass_matrix"]}')
self.parameters = {"step_size": state["step_size"], "inverse_mass_matrix":state["inverse_mass_matrix"]}
x_hat = state["s_hat_field"]
# Run forward model to compute correct density field
_ = self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=False)
@jax.custom_vjp
def logdensity_fn(x):
return - self._likelihood(x_hat, x, skip_density=True)
def logdensity_fn_fwd(x):
y = logdensity_fn(x)
return y, x # Save `x` for the backward pass
def logdensity_fn_bwd(x, cotangent_y):
grad_fn = jax.grad(self._likelihood, argnums=1)
grad_x = - grad_fn(x_hat, x, skip_density=True)
cotangent_x = grad_x * cotangent_y
return (cotangent_x,)
logdensity_fn.defvjp(logdensity_fn_fwd, logdensity_fn_bwd)
if self.parameters is None:
get_info = blackjax.adaptation.base.get_filter_adapt_info_fn(
state_keys={'position', 'logdensity'},
info_keys={'acceptance_rate'},
adapt_state_keys={'step_size', 'inverse_mass_matrix'}
)
logdensity_fn_jitted = jax.jit(logdensity_fn)
warmup = blackjax.window_adaptation(
blackjax.nuts,
logdensity_fn_jitted,
target_acceptance_rate=self.warmup_target_acceptance_rate,
progress_bar=True,
adaptation_info_fn=get_info
)
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
# x = info.state.position
self.y[:] = bj_state.position
# Save results to the borg state object
# state.newArray1d("inverse_mass_matrix", len(self.parameters['inverse_mass_matrix']), True)
state["inverse_mass_matrix"][:] = self.parameters['inverse_mass_matrix']
state["step_size"] = float(self.parameters['step_size'])
nuts = blackjax.nuts(logdensity_fn, **self.parameters)
bias_state = nuts.init(self.y)
nuts_step = nuts.step
self.rng_key, sample_key = jax.random.split(self.rng_key)
bias_state, _ = nuts_step(sample_key, bias_state)
self.y[:] = bias_state.position
self._update_attrs(self.y)
to_set = {k:v for k, v in zip(self.attributes, self.y)}
self.likelihood.fwd_param.setModelParams(to_set)
def restore(self, state: borg.likelihood.MarkovState):
state.newArray1d("inverse_mass_matrix", len(self.attributes), True)
state.newScalar('step_size', 0., True)
attrname = f"{self.prefix}attributes"
state.newArray1d(attrname, len(self.attributes), True)
self.y = state[attrname]
def _loaded():
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
self._update_attrs(self.y)
state.subscribeLoaded(attrname, _loaded)
mvsname = f"{self.prefix}mvs_state"
state.newArray1d(mvsname, len(self.attributes), True)
self.mvs_state = state[mvsname]
class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
def __init__(
self,
prefix: str,
likelihood: borg.likelihood.BaseLikelihood,
attributes: List[str],
transform_attributes: List[object],
inv_transform_attributes: List[object],
prior: List[object],
rng_seed: int,
warmup_nsteps: int,
warmup_target_acceptance_rate: float
):
super().__init__()
self.likelihood = likelihood
self.attributes = attributes
self.prefix = prefix
self.transform_attributes = transform_attributes
self.inv_transform_attributes = inv_transform_attributes
self.prior_attributes = prior
self.rng_key = jax.random.key(rng_seed)
self.warmup_nsteps = warmup_nsteps
self.warmup_target_acceptance_rate = warmup_target_acceptance_rate
self.parameters = None
def initialize(self, state: borg.likelihood.MarkovState):
self.restore(state)
for i in range(len(self.attributes)):
attr = self.likelihood.fwd_param.getModelParam(
'mod_null',
self.attributes[i]
)
self.y_real[i] = attr
if isinstance(attr, float):
attr = jnp.array(attr)
if self.inv_transform_attributes[i]:
self.y[i] = self.inv_transform_attributes[i](attr).item()
else:
self.y[i] = attr.item()
myprint(f"initial {self.y=}")
def _update_attrs(self, x):
for i in range(len(self.attributes)):
xnew = x[i]
if self.transform_attributes[i]:
xnew = self.transform_attributes[i](xnew)
self.likelihood.model_params[self.attributes[i]] = xnew
def _likelihood(self, x_hat, x, skip_density=False):
self._update_attrs(x)
return self.likelihood.logLikelihoodComplex(
x_hat, False, skip_density=skip_density, update_from_model=False,
)
def _grad_likelihood(self, x_hat, x, skip_density=False):
grad_fn = jax.grad(self._likelihood, argnums=1)
grad_x = grad_fn(x_hat, x, skip_density=skip_density)
return grad_x
def _prior(self, x):
prior_value = 0.0
for i in range(len(self.attributes)):
if self.prior_attributes[i]:
prior_value += self.prior_attributes[i](x[i])
return prior_value
def _grad_prior(self, x):
grad_x = jax.grad(self._prior, argnums=0)(x)
return grad_x
def sample(self, state: borg.likelihood.MarkovState):
myprint(f"Sampling attributes {self.attributes}")
# Update the sampling hyperparameters if possible
if (self.parameters is None) and (state["step_size"] > 0) and all(state["inverse_mass_matrix"] > 0):
myprint('Using loaded hyperparameters for model parameter sampling')
myprint(f'Step size: {state["step_size"]}')
myprint(f'Inverse mass matrix: {state["inverse_mass_matrix"]}')
self.parameters = {"step_size": state["step_size"], "inverse_mass_matrix":state["inverse_mass_matrix"]}
x_hat = state["s_hat_field"]
# Run forward model to compute correct density field
@ -321,15 +486,15 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
@jax.custom_vjp
def logdensity_fn(x):
self._update_attrs(x)
return -self.likelihood.logLikelihoodComplex(x_hat, False, skip_density=True, update_from_model=False)
p = self._likelihood(x_hat, x, skip_density=True) + self._prior(x)
return - p
def logdensity_fn_fwd(x):
y = logdensity_fn(x)
return y, x # Save `x` for the backward pass
def logdensity_fn_bwd(x, cotangent_y):
grad_fn = jax.grad(self._likelihood, argnums=1)
grad_x = grad_fn(x_hat, x, skip_density=True)
grad_x = - (self._grad_likelihood(x_hat, x, skip_density=True) + self._grad_prior(x))
cotangent_x = grad_x * cotangent_y
return (cotangent_x,)
@ -363,25 +528,37 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
self.y[:] = bias_state.position
self._update_attrs(self.y)
for i in range(len(self.attributes)):
self.y_real[i] = self.likelihood.model_params[self.attributes[i]]
to_set = {k:v for k, v in zip(self.attributes, self.y)}
to_set = {k:v for k, v in zip(self.attributes, self.y_real)}
self.likelihood.fwd_param.setModelParams(to_set)
def restore(self, state: borg.likelihood.MarkovState):
attrname = f"{self.prefix}attributes"
state.newArray1d(attrname, len(self.attributes), True)
self.y = state[attrname]
state.newArray1d("inverse_mass_matrix", len(self.attributes), True)
state.newScalar('step_size', 0., True)
# Define attribute names
attrname_real = self.prefix
attrname_var = f"{self.prefix}_hmc_var"
# Initialize attributes in the state
state.newArray1d(attrname_real, len(self.attributes), True)
self.y_real = state[attrname_real]
state.newArray1d(attrname_var, len(self.attributes), True)
self.y = state[attrname_var]
def _loaded():
myprint(f"Reinjecting parameters for bias sampler: {self.attributes}")
self._update_attrs(self.y)
# Ensure `self.y` is a JAX array
self._update_attrs(jnp.array(self.y))
state.subscribeLoaded(attrname, _loaded)
mvsname = f"{self.prefix}mvs_state"
state.newArray1d(mvsname, len(self.attributes), True)
self.mvs_state = state[mvsname]
# Subscribe to updates
state.subscribeLoaded(attrname_var, _loaded)

View file

@ -7,6 +7,7 @@ import jax
import jax.numpy as jnp
import scipy.integrate
import math
import re
# Output stream management
cons = borg.console()
@ -102,12 +103,10 @@ def get_action():
# Check that the this is the line was want and find action
cmd = "hades_python"
assert cmd in last_line
idx = last_line.index(cmd)
last_line = last_line[idx+len(cmd):]
while last_line[0] == ' ':
last_line = last_line[1:]
idx = last_line.index(' ')
last_line = last_line[:idx].upper()
# Pattern matches the command followed by a word (non-space characters), captures the word after the last occurrence
pattern = rf'(?:{re.escape(cmd)})(?=\s)(?:\s+(\S+))'
matches = re.findall(pattern, last_line)
last_line = matches[-1]
myprint(f'Running BORG mode: {last_line}')
return last_line

View file

@ -20,15 +20,15 @@ bias_sampler_blocked= true
nmean_sampler_blocked= true
sigma8_sampler_blocked = true
omega_m_sampler_blocked = true
muA_sampler_blocked = false
alpha_sampler_blocked = false
lam_sampler_blocked = false
sig_v_sampler_blocked = false
muA_sampler_blocked = true
alpha_sampler_blocked = true
lam_sampler_blocked = true
sig_v_sampler_blocked = true
bulk_flow_sampler_blocked = false
ares_heat = 1.0
[mcmc]
number_to_generate = 10000
number_to_generate = 3
warmup_model = 0
warmup_cosmo = 0
random_ic = false
@ -42,13 +42,18 @@ max_timesteps = 50
mixing = 1
[sampling]
algorithm = blackjax
algorithm = BlackJax
epsilon = 0.005
Nsteps = 20
refresh = 0.1
rng_seed = 1
warmup_nsteps = 100
warmup_target_acceptance_rate = 0.7
mvlam_mua = 0.01
mvlam_alpha = 0.5
mvlam_lam = 1.0
mvlam_bulk_flow = 20
mvlam_sig_v = 10
[model]
gravity = lpt

Binary file not shown.

Before

Width:  |  Height:  |  Size: 98 KiB

After

Width:  |  Height:  |  Size: 2.9 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 305 KiB

After

Width:  |  Height:  |  Size: 365 KiB

File diff suppressed because one or more lines are too long

View file

@ -18,7 +18,7 @@ set -e
# Path variables
BORG=/data101/bartlett/build_borg/tools/hades_python/hades_python
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_blackjax
RUN_DIR=/data101/bartlett/fsigma8/borg_velocity/test_dir_blackjax_restart
mkdir -p $RUN_DIR
cd $RUN_DIR
@ -33,4 +33,5 @@ set -x
# Just ICs
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
cp $INI_FILE ini_file.ini
$BORG INIT ini_file.ini
# $BORG INIT ini_file.ini
$BORG RESUME ini_file.ini