From 3eb1b0bccc096276ed075445548a926a18040533 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Sat, 22 Aug 2020 23:24:25 -0400 Subject: [PATCH] Add power spectrum tracking --- map2map/data/figures.py | 55 ++++++++++++++++++++++++++++++++++++++--- map2map/train.py | 24 +++++++++++++++--- 2 files changed, 72 insertions(+), 7 deletions(-) diff --git a/map2map/data/figures.py b/map2map/data/figures.py index 024071a..11a7bb6 100644 --- a/map2map/data/figures.py +++ b/map2map/data/figures.py @@ -8,6 +8,8 @@ import matplotlib.pyplot as plt from matplotlib.colors import Normalize, LogNorm, SymLogNorm from matplotlib.cm import ScalarMappable +from ..models import lag2eul, power + def quantize(x): return 2 ** round(log2(x), ndigits=1) @@ -15,14 +17,15 @@ def quantize(x): def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): """Plot slices of fields of more than 2 spatial dimensions. + + Each field should have a channel dimension followed by spatial dimensions, + i.e. no batch dimension. """ plt.close('all') - fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor) - else field for field in fields] + assert all(isinstance(field, torch.Tensor) for field in fields) - assert all(isinstance(field, np.ndarray) for field in fields) - assert all(field.ndim == fields[0].ndim for field in fields) + fields = [field.detach().cpu().numpy() for field in fields] nc = max(field.shape[0] for field in fields) nf = len(fields) @@ -110,3 +113,47 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): fig.tight_layout() return fig + + +def plt_power(*fields, l2e=False, label=None): + """Plot power spectra of fields. + + Each field should have batch and channel dimensions followed by spatial + dimensions. + + Optionally the field can be transformed by lag2eul first. + + See `map2map.models.power`. + """ + plt.close('all') + + if label is not None: + assert len(label) == len(fields) + else: + label = [None] * len(fields) + + with torch.no_grad(): + if l2e: + fields = lag2eul(*fields) + + ks, Ps = [], [] + for field in fields: + k, P, _ = power(field) + ks.append(k) + Ps.append(P) + + ks = [k.cpu().numpy() for k in ks] + Ps = [P.cpu().numpy() for P in Ps] + + fig, axes = plt.subplots(figsize=(4.8, 3.6), dpi=150) + + for k, P, l in zip(ks, Ps, label): + axes.loglog(k, P, label=l, alpha=0.7) + + axes.legend() + axes.set_xlabel('unnormalized wavenumber') + axes.set_ylabel('unnormalized power') + + fig.tight_layout() + + return fig diff --git a/map2map/train.py b/map2map/train.py index a7f18dd..1291d4d 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -14,7 +14,7 @@ from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset, DistFieldSampler -from .data.figures import plt_slices +from .data.figures import plt_slices, plt_power from . import models from .models import narrow_cast, resample, Lag2Eul from .utils import import_attr, load_model_state_dict @@ -306,9 +306,18 @@ def train(epoch, loader, model, lag2eul, criterion, title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], ) - logger.add_figure('fig/epoch/train', fig, global_step=epoch+1) + logger.add_figure('fig/train', fig, global_step=epoch+1) fig.clf() + #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) + #logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1) + #fig.clf() + + #fig = plt_power(input, lag_out, lag_tgt, l2e=True, + # label=['in', 'out', 'tgt']) + #logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1) + #fig.clf() + return epoch_loss @@ -358,9 +367,18 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args): title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], ) - logger.add_figure('fig/epoch/val', fig, global_step=epoch+1) + logger.add_figure('fig/val', fig, global_step=epoch+1) fig.clf() + #fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt']) + #logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1) + #fig.clf() + + #fig = plt_power(input, lag_out, lag_tgt, l2e=True, + # label=['in', 'out', 'tgt']) + #logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1) + #fig.clf() + return epoch_loss