Combine input and output figures
parent a46746287a
commit 75b1c19dcd
@@ -1,4 +1,5 @@
 from math import log2, log10, ceil
+import warnings
 import torch
 import numpy as np
 import matplotlib
@@ -9,36 +10,38 @@ from matplotlib.cm import ScalarMappable
 
 
 def fig3d(*fields, size=64, cmap=None, norm=None):
-    fields = [f.detach().cpu().numpy() if isinstance(f, torch.Tensor) else f
-              for f in fields]
+    fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
+              else field for field in fields]
 
-    assert all(isinstance(f, np.ndarray) for f in fields)
+    assert all(isinstance(field, np.ndarray) for field in fields)
 
-    nc = fields[-1].shape[0]
+    nc = max(field.shape[0] for field in fields)
     nf = len(fields)
 
     colorbar_frac = 0.15 / (0.85 * nc + 0.15)
-    fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(4 * nf, 4 * nc * (1 + colorbar_frac)))
+    fig, axes = plt.subplots(nc, nf, squeeze=False,
+                             figsize=(4 * nf, 4 * nc * (1 + colorbar_frac)))
 
     def quantize(x):
        return 2 ** round(log2(x), ndigits=1)
 
-    for f in range(nf):
-        all_non_neg = (fields[f] >= 0).all()
-        all_non_pos = (fields[f] <= 0).all()
+    for f, field in enumerate(fields):
+        all_non_neg = (field >= 0).all()
+        all_non_pos = (field <= 0).all()
 
         if cmap is None:
             if all_non_neg:
                 cmap_ = 'viridis'
             elif all_non_pos:
-                raise NotImplementedError
+                warnings.warn('no implementation for all non-positive values')
+                cmap_ = None
             else:
                 cmap_ = 'RdBu_r'
         else:
             cmap_ = cmap
 
         if norm is None:
-            l2, l1, h1, h2 = np.percentile(fields[f], [2.5, 16, 84, 97.5])
+            l2, l1, h1, h2 = np.percentile(field, [2.5, 16, 84, 97.5])
             w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
 
             if all_non_neg:
@@ -47,7 +50,8 @@ def fig3d(*fields, size=64, cmap=None, norm=None):
                 else:
                     norm_ = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
             elif all_non_pos:
-                raise NotImplementedError
+                warnings.warn('no implementation for all non-positive values')
+                norm_ = None
             else:
                 if w1 > 0.1 * w2:
                     vlim = quantize(2.5 * w1)
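With the two changes above, a field that is entirely non-positive no longer raises NotImplementedError: fig3d now warns for both the colormap and the norm and leaves cmap_ and norm_ as None, so imshow falls back to matplotlib's defaults. A minimal sketch of that behavior, assuming fig3d is importable from the module this diff touches (the file name is not shown here) and returns the matplotlib figure, as its use with add_figure further down implies:

import warnings
import numpy as np
from figures import fig3d   # assumed import path; adjust to this repo's module name

# hypothetical all-non-positive field of shape (channel, depth, height, width)
neg_field = -np.random.rand(1, 4, 64, 64)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    fig = fig3d(neg_field)                  # warns instead of raising
print([str(w.message) for w in caught])     # ['no implementation for all non-positive values', ...]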
@@ -58,8 +62,10 @@ def fig3d(*fields, size=64, cmap=None, norm=None):
         else:
             norm_ = norm
 
-        for c in range(nc):
-            axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap_, norm=norm_)
+        for c in range(field.shape[0]):
+            axes[c, f].imshow(field[c, 0, :size, :size], cmap=cmap_, norm=norm_)
+        for c in range(field.shape[0], nc):
+            axes[c, f].axis('off')
 
         plt.colorbar(ScalarMappable(norm=norm_, cmap=cmap_), ax=axes[:, f],
                      orientation='horizontal', fraction=colorbar_frac, pad=0.05)
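Taken together with nc = max(...) above, the plotting loop now fills only as many rows as a field has channels and hides the remaining panels with axis('off'), so input, output, and target fields with different channel counts fit in one figure. A usage sketch under the same import assumption as above; the array names and shapes are made up for illustration:

import numpy as np
from figures import fig3d   # assumed import path

# hypothetical fields of shape (channel, depth, height, width)
input_field = np.random.rand(1, 4, 128, 128)     # 1 channel, non-negative
output_field = np.random.randn(3, 4, 128, 128)   # 3 channels, signed
target_field = np.random.randn(3, 4, 128, 128)

# one figure with nf=3 columns and nc=max(1, 3, 3)=3 rows;
# the two unused panels in the input column are switched off
fig = fig3d(input_field, output_field, target_field)
fig.savefig('fields.png')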
@@ -357,12 +357,12 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         skip_chan = 0
         if args.adv and epoch >= args.adv_start and args.cgan:
             skip_chan = sum(args.in_chan)
-        logger.add_figure('fig/epoch/train/in', fig3d(input[-1]),
-                          global_step=epoch+1)
-        logger.add_figure('fig/epoch/train/out',
-                          fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
-                                output[-1, skip_chan:] - target[-1, skip_chan:]),
-                          global_step=epoch+1)
+        logger.add_figure('fig/epoch/train', fig3d(
+            input[-1],
+            output[-1, skip_chan:],
+            target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+        ), global_step=epoch+1)
 
     return epoch_loss
 
@@ -433,12 +433,12 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
         skip_chan = 0
         if args.adv and epoch >= args.adv_start and args.cgan:
             skip_chan = sum(args.in_chan)
-        logger.add_figure('fig/epoch/val/in', fig3d(input[-1]),
-                          global_step=epoch+1)
-        logger.add_figure('fig/epoch/val/out',
-                          fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
-                                output[-1, skip_chan:] - target[-1, skip_chan:]),
-                          global_step=epoch+1)
+        logger.add_figure('fig/epoch/val', fig3d(
+            input[-1],
+            output[-1, skip_chan:],
+            target[-1, skip_chan:],
+            output[-1, skip_chan:] - target[-1, skip_chan:],
+        ), global_step=epoch+1)
 
     return epoch_loss
 
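Both train() and validate() now log a single combined figure per phase under one tag ('fig/epoch/train' or 'fig/epoch/val') instead of the former .../in and .../out pair. A small sketch of that logging pattern, assuming logger is a torch.utils.tensorboard SummaryWriter as the add_figure calls suggest; the log directory and placeholder figure are hypothetical:

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/demo')      # hypothetical log directory

# placeholder for the combined fig3d output: input, output, target, output - target
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, title in zip(axes, ['input', 'output', 'target', 'output - target']):
    ax.set_title(title)
    ax.axis('off')

writer.add_figure('fig/epoch/train', fig, global_step=1)   # one tag, one figure per epoch
writer.close()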