Add power spectrum tracking

This commit is contained in:
Yin Li 2020-08-22 23:24:25 -04:00
parent afaf4675fe
commit 3eb1b0bccc
2 changed files with 72 additions and 7 deletions

View File

@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, SymLogNorm from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable from matplotlib.cm import ScalarMappable
from ..models import lag2eul, power
def quantize(x): def quantize(x):
return 2 ** round(log2(x), ndigits=1) return 2 ** round(log2(x), ndigits=1)
@ -15,14 +17,15 @@ def quantize(x):
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions. """Plot slices of fields of more than 2 spatial dimensions.
Each field should have a channel dimension followed by spatial dimensions,
i.e. no batch dimension.
""" """
plt.close('all') plt.close('all')
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor) assert all(isinstance(field, torch.Tensor) for field in fields)
else field for field in fields]
assert all(isinstance(field, np.ndarray) for field in fields) fields = [field.detach().cpu().numpy() for field in fields]
assert all(field.ndim == fields[0].ndim for field in fields)
nc = max(field.shape[0] for field in fields) nc = max(field.shape[0] for field in fields)
nf = len(fields) nf = len(fields)
@ -110,3 +113,47 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
fig.tight_layout() fig.tight_layout()
return fig return fig
def plt_power(*fields, l2e=False, label=None):
"""Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial
dimensions.
Optionally the field can be transformed by lag2eul first.
See `map2map.models.power`.
"""
plt.close('all')
if label is not None:
assert len(label) == len(fields)
else:
label = [None] * len(fields)
with torch.no_grad():
if l2e:
fields = lag2eul(*fields)
ks, Ps = [], []
for field in fields:
k, P, _ = power(field)
ks.append(k)
Ps.append(P)
ks = [k.cpu().numpy() for k in ks]
Ps = [P.cpu().numpy() for P in Ps]
fig, axes = plt.subplots(figsize=(4.8, 3.6), dpi=150)
for k, P, l in zip(ks, Ps, label):
axes.loglog(k, P, label=l, alpha=0.7)
axes.legend()
axes.set_xlabel('unnormalized wavenumber')
axes.set_ylabel('unnormalized power')
fig.tight_layout()
return fig

View File

@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, DistFieldSampler from .data import FieldDataset, DistFieldSampler
from .data.figures import plt_slices from .data.figures import plt_slices, plt_power
from . import models from . import models
from .models import narrow_cast, resample, Lag2Eul from .models import narrow_cast, resample, Lag2Eul
from .utils import import_attr, load_model_state_dict from .utils import import_attr, load_model_state_dict
@ -306,9 +306,18 @@ def train(epoch, loader, model, lag2eul, criterion,
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
) )
logger.add_figure('fig/epoch/train', fig, global_step=epoch+1) logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss return epoch_loss
@ -358,9 +367,18 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt', title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'], 'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
) )
logger.add_figure('fig/epoch/val', fig, global_step=epoch+1) logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf() fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss return epoch_loss