From db69e9f953c8f1297214cd15bb07aa0b37c7b877 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 3 Feb 2020 22:18:08 -0500 Subject: [PATCH] Add figures with tensorboard --- map2map/data/figures.py | 57 +++++++++++++++++++++++++++++++++++++++++ map2map/train.py | 11 ++++++++ 2 files changed, 68 insertions(+) create mode 100644 map2map/data/figures.py diff --git a/map2map/data/figures.py b/map2map/data/figures.py new file mode 100644 index 0000000..3f8ee5f --- /dev/null +++ b/map2map/data/figures.py @@ -0,0 +1,57 @@ +from math import log2, log10, ceil +import torch +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from matplotlib.colors import Normalize, LogNorm, SymLogNorm +from matplotlib.cm import ScalarMappable + + +def fig3d(*fields, size=64, cmap=None, norm=None): + fields = [f.detach().cpu().numpy() if isinstance(f, torch.Tensor) else f + for f in fields] + + assert all(isinstance(f, np.ndarray) for f in fields) + + nc = fields[-1].shape[0] + nf = len(fields) + + fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(5 * nf, 4.25 * nc)) + + if cmap is None: + if (fields[-1] >= 0).all(): + cmap = 'viridis' + elif (fields[-1] <= 0).all(): + raise NotImplementedError + else: + cmap = 'RdBu_r' + + if norm is None: + def quantize(x): + return 2 ** round(log2(x), ndigits=1) + + l2, l1, h1, h2 = np.percentile(fields[-1], [2.5, 16, 84, 97.5]) + w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2 + + if (fields[-1] >= 0).all(): + if h1 > 0.1 * h2: + norm = Normalize(vmin=0, vmax=quantize(2 * h2)) + else: + norm = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2)) + elif (fields[-1] <= 0).all(): + raise NotImplementedError + else: + if w1 > 0.1 * w2: + vlim = quantize(2.5 * w1) + norm = Normalize(vmin=-vlim, vmax=vlim) + else: + vlim = quantize(w2) + norm = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim) + + for c in range(nc): + for f in range(nf): + axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap, norm=norm) + plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=axes) + + return fig diff --git a/map2map/train.py b/map2map/train.py index ba1ae30..5d70f11 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -10,6 +10,7 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset +from .data.figures import fig3d from . import models from .models import narrow_like from .models.adversary import adv_model_wrapper, adv_criterion_wrapper @@ -322,6 +323,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, 'real': epoch_loss[4], }, global_step=epoch+1) + skip_chan = sum(in_chan) if args.adv and args.cgan else 0 + args.logger.add_figure('fig/epoch/train', + fig3d(output[-1, skip_chan:], target[-1, skip_chan:]), + global_step =epoch+1) + return epoch_loss @@ -383,4 +389,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): 'real': epoch_loss[4], }, global_step=epoch+1) + skip_chan = sum(in_chan) if args.adv and args.cgan else 0 + args.logger.add_figure('fig/epoch/val', + fig3d(output[-1, skip_chan:], target[-1, skip_chan:]), + global_step =epoch+1) + return epoch_loss