Add slice plotting for field >2d
This commit is contained in:
parent
5c4d244a54
commit
819e77cd86
@ -9,14 +9,18 @@ from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
|||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
|
|
||||||
|
|
||||||
def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
|
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
||||||
|
"""Plot slices of fields of more than 2 spatial dimensions.
|
||||||
|
"""
|
||||||
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
||||||
else field for field in fields]
|
else field for field in fields]
|
||||||
|
|
||||||
assert all(isinstance(field, np.ndarray) for field in fields)
|
assert all(isinstance(field, np.ndarray) for field in fields)
|
||||||
|
assert all(field.ndim == fields[0].ndim for field in fields)
|
||||||
|
|
||||||
nc = max(field.shape[0] for field in fields)
|
nc = max(field.shape[0] for field in fields)
|
||||||
nf = len(fields)
|
nf = len(fields)
|
||||||
|
nd = fields[0].ndim - 1
|
||||||
|
|
||||||
if title is not None:
|
if title is not None:
|
||||||
assert len(title) == nf
|
assert len(title) == nf
|
||||||
@ -73,8 +77,8 @@ def fig3d(*fields, size=64, title=None, cmap=None, norm=None):
|
|||||||
norm_ = norm
|
norm_ = norm
|
||||||
|
|
||||||
for c in range(field.shape[0]):
|
for c in range(field.shape[0]):
|
||||||
axes[c, f].pcolormesh(field[c, 0, :size, :size],
|
s = (c,) + (0,) * (nd - 2) + (slice(64),) * 2
|
||||||
cmap=cmap_, norm=norm_)
|
axes[c, f].pcolormesh(field[s], cmap=cmap_, norm=norm_)
|
||||||
|
|
||||||
axes[c, f].set_aspect('equal')
|
axes[c, f].set_aspect('equal')
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from .data import FieldDataset, GroupedRandomSampler
|
from .data import FieldDataset, GroupedRandomSampler
|
||||||
from .data.figures import fig3d
|
from .data.figures import plt_slices
|
||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_like,
|
from .models import (narrow_like,
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
@ -436,7 +436,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
skip_chan = 0
|
skip_chan = 0
|
||||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||||
skip_chan = sum(args.in_chan)
|
skip_chan = sum(args.in_chan)
|
||||||
logger.add_figure('fig/epoch/train', fig3d(
|
logger.add_figure('fig/epoch/train', plt_slices(
|
||||||
input[-1],
|
input[-1],
|
||||||
output[-1, skip_chan:],
|
output[-1, skip_chan:],
|
||||||
target[-1, skip_chan:],
|
target[-1, skip_chan:],
|
||||||
@ -511,7 +511,7 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
skip_chan = 0
|
skip_chan = 0
|
||||||
if args.adv and epoch >= args.adv_start and args.cgan:
|
if args.adv and epoch >= args.adv_start and args.cgan:
|
||||||
skip_chan = sum(args.in_chan)
|
skip_chan = sum(args.in_chan)
|
||||||
logger.add_figure('fig/epoch/val', fig3d(
|
logger.add_figure('fig/epoch/val', plt_slices(
|
||||||
input[-1],
|
input[-1],
|
||||||
output[-1, skip_chan:],
|
output[-1, skip_chan:],
|
||||||
target[-1, skip_chan:],
|
target[-1, skip_chan:],
|
||||||
|
Loading…
Reference in New Issue
Block a user