Add slice plotting for field >2d

This commit is contained in:
Yin Li 2020-07-10 14:58:08 -04:00
parent 5c4d244a54
commit 819e77cd86
2 changed files with 10 additions and 6 deletions

View File

@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable
def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions.
"""
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
else field for field in fields]
assert all(isinstance(field, np.ndarray) for field in fields)
assert all(field.ndim == fields[0].ndim for field in fields)
nc = max(field.shape[0] for field in fields)
nf = len(fields)
nd = fields[0].ndim - 1
if title is not None:
assert len(title) == nf
@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
norm_ = norm
for c in range(field.shape[0]):
axes[c, f].pcolormesh(field[c, 0, :size, :size],
cmap=cmap_, norm=norm_)
s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
axes[c, f].set_aspect('equal')

View File

@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data.figures import fig3d
from .data.figures import plt_slices
from . import models
from .models import (narrow_like,
adv_model_wrapper, adv_criterion_wrapper,
@ -436,7 +436,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/train', fig3d(
logger.add_figure('fig/epoch/train', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],
@ -511,7 +511,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
skip_chan = 0
if args.adv and epoch >= args.adv_start and args.cgan:
skip_chan = sum(args.in_chan)
logger.add_figure('fig/epoch/val', fig3d(
logger.add_figure('fig/epoch/val', plt_slices(
input[-1],
output[-1, skip_chan:],
target[-1, skip_chan:],