Add diagnosis plots for model parameter tuning

This commit is contained in:
Deaglan Bartlett 2024-10-28 17:43:30 +01:00
parent e16b14e867
commit ea7019b2db
6 changed files with 44 additions and 4 deletions

View file

@ -5,6 +5,8 @@ import jax.numpy as jnp
import jax
import blackjax
from borg_velocity.utils import myprint
import matplotlib.pyplot as plt
import os
class HMCBiasSampler(borg.samplers.PyBaseSampler):
@ -352,7 +354,14 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
)
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
# x = info.state.position
x = info.state.position
# Make and save summary figures
dirname = os.getcwd()
trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4)
trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4)
trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4)
self.y[:] = bj_state.position
# Save results to the borg state object
@ -562,3 +571,34 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
def trace_plot(names, samples, savename, ncol=4):
if len(names) == 1:
fig, ax = plt.subplots(1, 1, figsize=(4,3))
ax.plot(samples[:])
ax.set_title(names[0])
ax.set_xlabel('Step Number')
else:
nrow = int(np.ceil(len(names) / ncol))
fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*4, nrow*3))
axs = np.atleast_2d(axs)
axs = axs.flatten()
for i in range(len(names)):
axs[i].plot(samples[:,i])
axs[i].set_title(names[i])
axs[i].set_xlabel('Step Number')
for i in range(len(names), len(axs)):
axs[i].remove()
fig.tight_layout()
myprint(f'Saving tuning diagnosis figure: {savename}')
fig.savefig(savename, facecolor='white', bbox_inches='tight')
fig.clf()
plt.close(fig)
return

View file

@ -47,7 +47,7 @@ epsilon = 0.005
Nsteps = 20
refresh = 0.1
rng_seed = 1
warmup_nsteps = 100
warmup_nsteps = 200
warmup_target_acceptance_rate = 0.7
mvlam_mua = 0.01
mvlam_alpha = 0.5

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
figs/tuning_params.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

BIN
figs/tuning_step_size.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

View file

@ -33,5 +33,5 @@ set -x
# Just ICs
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
cp $INI_FILE ini_file.ini
# $BORG INIT ini_file.ini
$BORG RESUME ini_file.ini
$BORG INIT ini_file.ini
# $BORG RESUME ini_file.ini