diff --git a/borg_velocity/samplers.py b/borg_velocity/samplers.py index 17b4b70..a257cf9 100644 --- a/borg_velocity/samplers.py +++ b/borg_velocity/samplers.py @@ -5,6 +5,8 @@ import jax.numpy as jnp import jax import blackjax from borg_velocity.utils import myprint +import matplotlib.pyplot as plt +import os class HMCBiasSampler(borg.samplers.PyBaseSampler): @@ -352,7 +354,14 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler): ) rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3) (bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps) - # x = info.state.position + x = info.state.position + + # Make and save summary figures + dirname = os.getcwd() + trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4) + trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4) + trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4) + self.y[:] = bj_state.position # Save results to the borg state object @@ -562,3 +571,34 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler): +def trace_plot(names, samples, savename, ncol=4): + + if len(names) == 1: + + fig, ax = plt.subplots(1, 1, figsize=(4,3)) + ax.plot(samples[:]) + ax.set_title(names[0]) + ax.set_xlabel('Step Number') + + else: + + nrow = int(np.ceil(len(names) / ncol)) + + fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*4, nrow*3)) + axs = np.atleast_2d(axs) + axs = axs.flatten() + + for i in range(len(names)): + axs[i].plot(samples[:,i]) + axs[i].set_title(names[i]) + axs[i].set_xlabel('Step Number') + for i in range(len(names), len(axs)): + axs[i].remove() + + fig.tight_layout() + myprint(f'Saving tuning diagnosis figure: {savename}') + fig.savefig(savename, facecolor='white', bbox_inches='tight') + fig.clf() + plt.close(fig) + + return diff --git a/conf/supranta_ini.ini b/conf/supranta_ini.ini index f0b21a9..4ab318e 100644 --- a/conf/supranta_ini.ini +++ b/conf/supranta_ini.ini @@ -47,7 +47,7 @@ epsilon = 0.005 Nsteps = 20 refresh = 0.1 rng_seed = 1 -warmup_nsteps = 100 +warmup_nsteps = 200 warmup_target_acceptance_rate = 0.7 mvlam_mua = 0.01 mvlam_alpha = 0.5 diff --git a/figs/tuning_inverse_mass_matrix.png b/figs/tuning_inverse_mass_matrix.png new file mode 100644 index 0000000..4328516 Binary files /dev/null and b/figs/tuning_inverse_mass_matrix.png differ diff --git a/figs/tuning_params.png b/figs/tuning_params.png new file mode 100644 index 0000000..1102196 Binary files /dev/null and b/figs/tuning_params.png differ diff --git a/figs/tuning_step_size.png b/figs/tuning_step_size.png new file mode 100644 index 0000000..2cd1c39 Binary files /dev/null and b/figs/tuning_step_size.png differ diff --git a/scripts/run_borg.sh b/scripts/run_borg.sh index 8d7dfd0..e0721cc 100755 --- a/scripts/run_borg.sh +++ b/scripts/run_borg.sh @@ -33,5 +33,5 @@ set -x # Just ICs INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini cp $INI_FILE ini_file.ini -# $BORG INIT ini_file.ini -$BORG RESUME ini_file.ini +$BORG INIT ini_file.ini +# $BORG RESUME ini_file.ini