Add slice plotting for field >2d
This commit is contained in:
parent
5c4d244a54
commit
819e77cd86
2 changed files with 10 additions and 6 deletions
|
@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
|||
from matplotlib.cm import ScalarMappable
|
||||
|
||||
|
||||
def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
|
||||
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
||||
"""Plot slices of fields of more than 2 spatial dimensions.
|
||||
"""
|
||||
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
||||
else field for field in fields]
|
||||
|
||||
assert all(isinstance(field, np.ndarray) for field in fields)
|
||||
assert all(field.ndim == fields[0].ndim for field in fields)
|
||||
|
||||
nc = max(field.shape[0] for field in fields)
|
||||
nf = len(fields)
|
||||
nd = fields[0].ndim - 1
|
||||
|
||||
if title is not None:
|
||||
assert len(title) == nf
|
||||
|
@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
|
|||
norm_ = norm
|
||||
|
||||
for c in range(field.shape[0]):
|
||||
axes[c, f].pcolormesh(field[c, 0, :size, :size],
|
||||
cmap=cmap_, norm=norm_)
|
||||
s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
|
||||
axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
|
||||
|
||||
axes[c, f].set_aspect('equal')
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .data import FieldDataset, GroupedRandomSampler
|
||||
from .data.figures import fig3d
|
||||
from .data.figures import plt_slices
|
||||
from . import models
|
||||
from .models import (narrow_like,
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
|
@ -436,7 +436,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/train', fig3d(
|
||||
logger.add_figure('fig/epoch/train', plt_slices(
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
|
@ -511,7 +511,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||
skip_chan = 0
|
||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||
skip_chan = sum(args.in_chan)
|
||||
logger.add_figure('fig/epoch/val', fig3d(
|
||||
logger.add_figure('fig/epoch/val', plt_slices(
|
||||
input[-1],
|
||||
output[-1, skip_chan:],
|
||||
target[-1, skip_chan:],
|
||||
|
|
Loading…
Reference in a new issue