Add diagnosis plots for model parameter tuning
This commit is contained in:
parent
e16b14e867
commit
ea7019b2db
6 changed files with 44 additions and 4 deletions
|
@ -5,6 +5,8 @@ import jax.numpy as jnp
|
||||||
import jax
|
import jax
|
||||||
import blackjax
|
import blackjax
|
||||||
from borg_velocity.utils import myprint
|
from borg_velocity.utils import myprint
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
class HMCBiasSampler(borg.samplers.PyBaseSampler):
|
class HMCBiasSampler(borg.samplers.PyBaseSampler):
|
||||||
|
@ -352,7 +354,14 @@ class BlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
||||||
)
|
)
|
||||||
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
|
rng_key, warmup_key, sample_key = jax.random.split(self.rng_key, 3)
|
||||||
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
(bj_state, self.parameters), info = warmup.run(warmup_key, self.y, num_steps=self.warmup_nsteps)
|
||||||
# x = info.state.position
|
x = info.state.position
|
||||||
|
|
||||||
|
# Make and save summary figures
|
||||||
|
dirname = os.getcwd()
|
||||||
|
trace_plot(self.attributes, info.state.position, f'{dirname}/tuning_params.png', ncol=4)
|
||||||
|
trace_plot(self.attributes, info.adaptation_state.inverse_mass_matrix, f'{dirname}/tuning_inverse_mass_matrix.png', ncol=4)
|
||||||
|
trace_plot(['stepsize'], info.adaptation_state.step_size, f'{dirname}/tuning_step_size.png', ncol=4)
|
||||||
|
|
||||||
self.y[:] = bj_state.position
|
self.y[:] = bj_state.position
|
||||||
|
|
||||||
# Save results to the borg state object
|
# Save results to the borg state object
|
||||||
|
@ -562,3 +571,34 @@ class TransformedBlackJaxBiasSampler(borg.samplers.PyBaseSampler):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def trace_plot(names, samples, savename, ncol=4):
|
||||||
|
|
||||||
|
if len(names) == 1:
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(1, 1, figsize=(4,3))
|
||||||
|
ax.plot(samples[:])
|
||||||
|
ax.set_title(names[0])
|
||||||
|
ax.set_xlabel('Step Number')
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
nrow = int(np.ceil(len(names) / ncol))
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(nrow, ncol, figsize=(ncol*4, nrow*3))
|
||||||
|
axs = np.atleast_2d(axs)
|
||||||
|
axs = axs.flatten()
|
||||||
|
|
||||||
|
for i in range(len(names)):
|
||||||
|
axs[i].plot(samples[:,i])
|
||||||
|
axs[i].set_title(names[i])
|
||||||
|
axs[i].set_xlabel('Step Number')
|
||||||
|
for i in range(len(names), len(axs)):
|
||||||
|
axs[i].remove()
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
myprint(f'Saving tuning diagnosis figure: {savename}')
|
||||||
|
fig.savefig(savename, facecolor='white', bbox_inches='tight')
|
||||||
|
fig.clf()
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
|
@ -47,7 +47,7 @@ epsilon = 0.005
|
||||||
Nsteps = 20
|
Nsteps = 20
|
||||||
refresh = 0.1
|
refresh = 0.1
|
||||||
rng_seed = 1
|
rng_seed = 1
|
||||||
warmup_nsteps = 100
|
warmup_nsteps = 200
|
||||||
warmup_target_acceptance_rate = 0.7
|
warmup_target_acceptance_rate = 0.7
|
||||||
mvlam_mua = 0.01
|
mvlam_mua = 0.01
|
||||||
mvlam_alpha = 0.5
|
mvlam_alpha = 0.5
|
||||||
|
|
BIN
figs/tuning_inverse_mass_matrix.png
Normal file
BIN
figs/tuning_inverse_mass_matrix.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 20 KiB |
BIN
figs/tuning_params.png
Normal file
BIN
figs/tuning_params.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 84 KiB |
BIN
figs/tuning_step_size.png
Normal file
BIN
figs/tuning_step_size.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 28 KiB |
|
@ -33,5 +33,5 @@ set -x
|
||||||
# Just ICs
|
# Just ICs
|
||||||
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
INI_FILE=/home/bartlett/fsigma8/borg_velocity/conf/supranta_ini.ini
|
||||||
cp $INI_FILE ini_file.ini
|
cp $INI_FILE ini_file.ini
|
||||||
# $BORG INIT ini_file.ini
|
$BORG INIT ini_file.ini
|
||||||
$BORG RESUME ini_file.ini
|
# $BORG RESUME ini_file.ini
|
||||||
|
|
Loading…
Add table
Reference in a new issue