Add better looking tensorboard images
This commit is contained in:
parent
996c0d3aed
commit
a88f27a3a1
@ -9,7 +9,7 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
|||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
|
|
||||||
|
|
||||||
def fig3d(*fields, size=64, cmap=None, norm=None):
|
def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
|
||||||
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
||||||
else field for field in fields]
|
else field for field in fields]
|
||||||
|
|
||||||
@ -18,9 +18,18 @@ def fig3d(*fields, size=64, cmap=None, norm=None):
|
|||||||
nc = max(field.shape[0] for field in fields)
|
nc = max(field.shape[0] for field in fields)
|
||||||
nf = len(fields)
|
nf = len(fields)
|
||||||
|
|
||||||
colorbar_frac = 0.15 / (0.85 * nc + 0.15)
|
if title is not None:
|
||||||
fig, axes = plt.subplots(nc, nf, squeeze=False,
|
assert len(title) == nf
|
||||||
figsize=(4 * nf, 4 * nc * (1 + colorbar_frac)))
|
|
||||||
|
im_size = 3
|
||||||
|
cbar_height = 0.5
|
||||||
|
cbar_frac = cbar_height / (nc * im_size + cbar_height)
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
nc, nf,
|
||||||
|
squeeze=False,
|
||||||
|
figsize=(nf * im_size, nc * im_size + cbar_height),
|
||||||
|
constrained_layout=True,
|
||||||
|
)
|
||||||
|
|
||||||
def quantize(x):
|
def quantize(x):
|
||||||
return 2 ** round(log2(x), ndigits=1)
|
return 2 ** round(log2(x), ndigits=1)
|
||||||
@ -63,11 +72,28 @@ def fig3d(*fields, size=64, cmap=None, norm=None):
|
|||||||
norm_ = norm
|
norm_ = norm
|
||||||
|
|
||||||
for c in range(field.shape[0]):
|
for c in range(field.shape[0]):
|
||||||
axes[c, f].imshow(field[c, 0, :size, :size], cmap=cmap_, norm=norm_)
|
axes[c, f].pcolormesh(field[c, 0, :size, :size],
|
||||||
|
cmap=cmap_, norm=norm_)
|
||||||
|
|
||||||
|
axes[c, f].set_aspect('equal')
|
||||||
|
|
||||||
|
axes[c, f].set_xticks([])
|
||||||
|
axes[c, f].set_yticks([])
|
||||||
|
|
||||||
|
if c == 0 and title is not None:
|
||||||
|
axes[c, f].set_title(title[f])
|
||||||
|
|
||||||
for c in range(field.shape[0], nc):
|
for c in range(field.shape[0], nc):
|
||||||
axes[c, f].axis('off')
|
axes[c, f].axis('off')
|
||||||
|
|
||||||
plt.colorbar(ScalarMappable(norm=norm_, cmap=cmap_), ax=axes[:, f],
|
fig.colorbar(
|
||||||
orientation='horizontal', fraction=colorbar_frac, pad=0.05)
|
ScalarMappable(norm=norm_, cmap=cmap_),
|
||||||
|
ax=axes[:, f],
|
||||||
|
orientation='horizontal',
|
||||||
|
fraction=cbar_frac,
|
||||||
|
pad=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fig.set_constrained_layout_pads(w_pad=0, h_pad=0, wspace=0, hspace=0)
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
@ -423,6 +423,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
output[-1, skip_chan:],
|
output[-1, skip_chan:],
|
||||||
target[-1, skip_chan:],
|
target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||||
|
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||||
), global_step=epoch+1)
|
), global_step=epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
@ -499,6 +500,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
output[-1, skip_chan:],
|
output[-1, skip_chan:],
|
||||||
target[-1, skip_chan:],
|
target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:],
|
output[-1, skip_chan:] - target[-1, skip_chan:],
|
||||||
|
title=['in', 'out', 'tgt', 'out - tgt'],
|
||||||
), global_step=epoch+1)
|
), global_step=epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
Loading…
Reference in New Issue
Block a user