Fix training hang due to constrained layout of matplotlib
This commit is contained in:
parent
bf9b0ba426
commit
8ce13e67f6
@ -24,61 +24,58 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
|||||||
|
|
||||||
if title is not None:
|
if title is not None:
|
||||||
assert len(title) == nf
|
assert len(title) == nf
|
||||||
|
cmap = np.broadcast_to(cmap, (nf,))
|
||||||
|
norm = np.broadcast_to(norm, (nf,))
|
||||||
|
|
||||||
im_size = 2
|
im_size = 2
|
||||||
cbar_height = 0.3
|
cbar_height = 0.2
|
||||||
cbar_frac = cbar_height / (nc * im_size + cbar_height)
|
|
||||||
fig, axes = plt.subplots(
|
fig, axes = plt.subplots(
|
||||||
nc, nf,
|
nc + 1, nf,
|
||||||
squeeze=False,
|
squeeze=False,
|
||||||
figsize=(nf * im_size, nc * im_size + cbar_height),
|
figsize=(nf * im_size, nc * im_size + cbar_height),
|
||||||
dpi=100,
|
dpi=100,
|
||||||
constrained_layout=True,
|
gridspec_kw={'height_ratios': nc * [im_size] + [cbar_height]}
|
||||||
)
|
)
|
||||||
|
|
||||||
def quantize(x):
|
def quantize(x):
|
||||||
return 2 ** round(log2(x), ndigits=1)
|
return 2 ** round(log2(x), ndigits=1)
|
||||||
|
|
||||||
for f, field in enumerate(fields):
|
for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)):
|
||||||
all_non_neg = (field >= 0).all()
|
all_non_neg = (field >= 0).all()
|
||||||
all_non_pos = (field <= 0).all()
|
all_non_pos = (field <= 0).all()
|
||||||
|
|
||||||
if cmap is None:
|
if cmap_col is None:
|
||||||
if all_non_neg:
|
if all_non_neg:
|
||||||
cmap_ = 'viridis'
|
cmap_col = 'viridis'
|
||||||
elif all_non_pos:
|
elif all_non_pos:
|
||||||
warnings.warn('no implementation for all non-positive values')
|
warnings.warn('no implementation for all non-positive values')
|
||||||
cmap_ = None
|
cmap_col = None
|
||||||
else:
|
else:
|
||||||
cmap_ = 'RdBu_r'
|
cmap_col = 'RdBu_r'
|
||||||
else:
|
|
||||||
cmap_ = cmap
|
|
||||||
|
|
||||||
if norm is None:
|
if norm_col is None:
|
||||||
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
|
l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
|
||||||
w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
|
w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
|
||||||
|
|
||||||
if all_non_neg:
|
if all_non_neg:
|
||||||
if h1 > 0.1 * h2:
|
if h1 > 0.1 * h2:
|
||||||
norm_ = Normalize(vmin=0, vmax=quantize(2 * h2))
|
norm_col = Normalize(vmin=0, vmax=quantize(2 * h2))
|
||||||
else:
|
else:
|
||||||
norm_ = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
|
norm_col = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
|
||||||
elif all_non_pos:
|
elif all_non_pos:
|
||||||
warnings.warn('no implementation for all non-positive values')
|
warnings.warn('no implementation for all non-positive values')
|
||||||
norm_ = None
|
norm_col = None
|
||||||
else:
|
else:
|
||||||
if w1 > 0.1 * w2:
|
if w1 > 0.1 * w2:
|
||||||
vlim = quantize(2.5 * w1)
|
vlim = quantize(2.5 * w1)
|
||||||
norm_ = Normalize(vmin=-vlim, vmax=vlim)
|
norm_col = Normalize(vmin=-vlim, vmax=vlim)
|
||||||
else:
|
else:
|
||||||
vlim = quantize(w2)
|
vlim = quantize(w2)
|
||||||
norm_ = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim)
|
norm_col = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim)
|
||||||
else:
|
|
||||||
norm_ = norm
|
|
||||||
|
|
||||||
for c in range(field.shape[0]):
|
for c in range(field.shape[0]):
|
||||||
s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
|
s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
|
||||||
axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
|
axes[c, f].pcolormesh(field[s], cmap=cmap_col, norm=norm_col)
|
||||||
|
|
||||||
axes[c, f].set_aspect('equal')
|
axes[c, f].set_aspect('equal')
|
||||||
|
|
||||||
@ -92,15 +89,11 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
|||||||
axes[c, f].axis('off')
|
axes[c, f].axis('off')
|
||||||
|
|
||||||
fig.colorbar(
|
fig.colorbar(
|
||||||
ScalarMappable(norm=norm_, cmap=cmap_),
|
ScalarMappable(norm=norm_col, cmap=cmap_col),
|
||||||
ax=axes[:, f],
|
cax=axes[-1, f],
|
||||||
orientation='horizontal',
|
orientation='horizontal',
|
||||||
fraction=cbar_frac,
|
|
||||||
pad=0,
|
|
||||||
shrink=0.9,
|
|
||||||
aspect=10,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
fig.set_constrained_layout_pads(w_pad=2/72, h_pad=2/72, wspace=0, hspace=0)
|
fig.tight_layout()
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
@ -330,12 +330,14 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||||||
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
|
logger.add_scalar('loss/epoch/train/lxe', epoch_loss.prod(),
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/train', plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
), global_step=epoch+1)
|
)
|
||||||
|
logger.add_figure('fig/epoch/train', fig, global_step=epoch+1)
|
||||||
|
fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
@ -380,12 +382,14 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||||||
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
|
logger.add_scalar('loss/epoch/val/lxe', epoch_loss.prod(),
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
|
|
||||||
logger.add_figure('fig/epoch/val', plt_slices(
|
fig = plt_slices(
|
||||||
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
input[-1], lag_out[-1], lag_tgt[-1], lag_out[-1] - lag_tgt[-1],
|
||||||
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
eul_out[-1], eul_tgt[-1], eul_out[-1] - eul_tgt[-1],
|
||||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
), global_step=epoch+1)
|
)
|
||||||
|
logger.add_figure('fig/epoch/val', fig, global_step=epoch+1)
|
||||||
|
fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user