Add diagnosis plots for model parameter tuning
This commit is contained in:
parent
e16b14e867
commit
ea7019b2db
6 changed files with 44 additions and 4 deletions
|
@ -5,6 +5,8 @@ import jax.numpy as jnp
|
|||
import jax
|
||||
import blackjax
|
||||
from borg_velocity.utils import myprint
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
|
||||
class HMCBiasSampler(borg.samplers.PyBaseSampler):
|
||||
|
@ -352,7 +354,14 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
)
|
||||
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
|
||||
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
||||
# x = info.state.position
|
||||
x = info.state.position
|
||||
|
||||
# Make and save summary figures
|
||||
dirname = os.getcwd()
|
||||
trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4)
|
||||
trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4)
|
||||
trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4)
|
||||
|
||||
self.y[:] = bj_state.position
|
||||
|
||||
# Save results to the borg state object
|
||||
|
@ -562,3 +571,34 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
|||
|
||||
|
||||
|
||||
def trace_plot(names, samples, savename, ncol=4):
|
||||
|
||||
if len(names) == 1:
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(4,3))
|
||||
ax.plot(samples[:])
|
||||
ax.set_title(names[0])
|
||||
ax.set_xlabel('Step Number')
|
||||
|
||||
else:
|
||||
|
||||
nrow = int(np.ceil(len(names) / ncol))
|
||||
|
||||
fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*4, nrow*3))
|
||||
axs = np.atleast_2d(axs)
|
||||
axs = axs.flatten()
|
||||
|
||||
for i in range(len(names)):
|
||||
axs[i].plot(samples[:,i])
|
||||
axs[i].set_title(names[i])
|
||||
axs[i].set_xlabel('Step Number')
|
||||
for i in range(len(names), len(axs)):
|
||||
axs[i].remove()
|
||||
|
||||
fig.tight_layout()
|
||||
myprint(f'Saving tuning diagnosis figure: {savename}')
|
||||
fig.savefig(savename, facecolor='white', bbox_inches='tight')
|
||||
fig.clf()
|
||||
plt.close(fig)
|
||||
|
||||
return
|
||||
|
|
|
@ -47,7 +47,7 @@ epsilon = 0.005
|
|||
Nsteps = 20
|
||||
refresh = 0.1
|
||||
rng_seed = 1
|
||||
warmup_nsteps = 100
|
||||
warmup_nsteps = 200
|
||||
warmup_target_acceptance_rate = 0.7
|
||||
mvlam_mua = 0.01
|
||||
mvlam_alpha = 0.5
|
||||
|
|
BIN
figs/tuning_inverse_mass_matrix.png
Normal file
BIN
figs/tuning_inverse_mass_matrix.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
BIN
figs/tuning_params.png
Normal file
BIN
figs/tuning_params.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 84 KiB |
BIN
figs/tuning_step_size.png
Normal file
BIN
figs/tuning_step_size.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 28 KiB |
|
@ -33,5 +33,5 @@ set -x
|
|||
# Just ICs
|
||||
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
||||
cp $INI_FILE ini_file.ini
|
||||
# $BORG INIT ini_file.ini
|
||||
$BORG RESUME ini_file.ini
|
||||
$BORG INIT ini_file.ini
|
||||
# $BORG RESUME ini_file.ini
|
||||
|
|
Loading…
Add table
Reference in a new issue