From 819e77cd8674ac49f02aea5c74e4d28d3c26b34e Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 10 Jul 2020 14:58:08 -0400 Subject: [PATCH] Add slice plotting for field >2d --- map2map/data/figures.py | 10 +++++++--- map2map/train.py | 6 +++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/map2map/data/figures.py b/map2map/data/figures.py index 02ebda3..86eedde 100644 --- a/map2map/data/figures.py +++ b/map2map/data/figures.py @@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm from matplotlib.cm import ScalarMappable -def fig3d(*fields, size=64, title=None, cmap=None, norm=None): +def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): + """Plot slices of fields of more than 2 spatial dimensions. + """ fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor) else field for field in fields] assert all(isinstance(field, np.ndarray) for field in fields) + assert all(field.ndim == fields[0].ndim for field in fields) nc = max(field.shape[0] for field in fields) nf = len(fields) + nd = fields[0].ndim - 1 if title is not None: assert len(title) == nf @@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None): norm_ = norm for c in range(field.shape[0]): - axes[c, f].pcolormesh(field[c, 0, :size, :size], - cmap=cmap_, norm=norm_) + s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2 + axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_) axes[c, f].set_aspect('equal') diff --git a/map2map/train.py b/map2map/train.py index 3d44fdf..ec9345e 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -15,7 +15,7 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset, GroupedRandomSampler -from .data.figures import fig3d +from .data.figures import plt_slices from . import models from .models import (narrow_like, adv_model_wrapper, adv_criterion_wrapper, @@ -436,7 +436,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, skip_chan = 0 if args.adv and epoch >= args.adv_start and args.cgan: skip_chan = sum(args.in_chan) - logger.add_figure('fig/epoch/train', fig3d( + logger.add_figure('fig/epoch/train', plt_slices( input[-1], output[-1, skip_chan:], target[-1, skip_chan:], @@ -511,7 +511,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, skip_chan = 0 if args.adv and epoch >= args.adv_start and args.cgan: skip_chan = sum(args.in_chan) - logger.add_figure('fig/epoch/val', fig3d( + logger.add_figure('fig/epoch/val', plt_slices( input[-1], output[-1, skip_chan:], target[-1, skip_chan:],