Add power spectrum tracking

This commit is contained in:
Yin Li 2020-08-22 23:24:25 -04:00
parent afaf4675fe
commit 3eb1b0bccc
2 changed files with 72 additions and 7 deletions

View File

@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
from matplotlib.cm import ScalarMappable
from ..models import lag2eul, power
def quantize(x):
return 2 ** round(log2(x), ndigits=1)
@ -15,14 +17,15 @@ def quantize(x):
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
"""Plot slices of fields of more than 2 spatial dimensions.
Each field should have a channel dimension followed by spatial dimensions,
i.e. no batch dimension.
"""
plt.close('all')
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
else field for field in fields]
assert all(isinstance(field, torch.Tensor) for field in fields)
assert all(isinstance(field, np.ndarray) for field in fields)
assert all(field.ndim == fields[0].ndim for field in fields)
fields = [field.detach().cpu().numpy() for field in fields]
nc = max(field.shape[0] for field in fields)
nf = len(fields)
@ -110,3 +113,47 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
fig.tight_layout()
return fig
def plt_power(*fields, l2e=False, label=None):
"""Plot power spectra of fields.
Each field should have batch and channel dimensions followed by spatial
dimensions.
Optionally the field can be transformed by lag2eul first.
See `map2map.models.power`.
"""
plt.close('all')
if label is not None:
assert len(label) == len(fields)
else:
label = [None] * len(fields)
with torch.no_grad():
if l2e:
fields = lag2eul(*fields)
ks, Ps = [], []
for field in fields:
k, P, _ = power(field)
ks.append(k)
Ps.append(P)
ks = [k.cpu().numpy() for k in ks]
Ps = [P.cpu().numpy() for P in Ps]
fig, axes = plt.subplots(figsize=(4.8, 3.6), dpi=150)
for k, P, l in zip(ks, Ps, label):
axes.loglog(k, P, label=l, alpha=0.7)
axes.legend()
axes.set_xlabel('unnormalized wavenumber')
axes.set_ylabel('unnormalized power')
fig.tight_layout()
return fig

View File

@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, DistFieldSampler
from .data.figures import plt_slices
from .data.figures import plt_slices, plt_power
from . import models
from .models import narrow_cast, resample, Lag2Eul
from .utils import import_attr, load_model_state_dict
@ -306,9 +306,18 @@ def train(epoch, loader, model, lag2eul, criterion,
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
)
logger.add_figure('fig/epoch/train', fig, global_step=epoch+1)
logger.add_figure('fig/train', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss
@ -358,9 +367,18 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
)
logger.add_figure('fig/epoch/val', fig, global_step=epoch+1)
logger.add_figure('fig/val', fig, global_step=epoch+1)
fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
#fig.clf()
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
# label=['in', 'out', 'tgt'])
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
#fig.clf()
return epoch_loss