Add power spectrum tracking
This commit is contained in:
parent
afaf4675fe
commit
3eb1b0bccc
2 changed files with 72 additions and 7 deletions
|
@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
|
|||
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
||||
from matplotlib.cm import ScalarMappable
|
||||
|
||||
from ..models import lag2eul, power
|
||||
|
||||
|
||||
def quantize(x):
|
||||
return 2 ** round(log2(x), ndigits=1)
|
||||
|
@ -15,14 +17,15 @@ def quantize(x):
|
|||
|
||||
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
||||
"""Plot slices of fields of more than 2 spatial dimensions.
|
||||
|
||||
Each field should have a channel dimension followed by spatial dimensions,
|
||||
i.e. no batch dimension.
|
||||
"""
|
||||
plt.close('all')
|
||||
|
||||
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
||||
else field for field in fields]
|
||||
assert all(isinstance(field, torch.Tensor) for field in fields)
|
||||
|
||||
assert all(isinstance(field, np.ndarray) for field in fields)
|
||||
assert all(field.ndim == fields[0].ndim for field in fields)
|
||||
fields = [field.detach().cpu().numpy() for field in fields]
|
||||
|
||||
nc = max(field.shape[0] for field in fields)
|
||||
nf = len(fields)
|
||||
|
@ -110,3 +113,47 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
|||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plt_power(*fields, l2e=False, label=None):
|
||||
"""Plot power spectra of fields.
|
||||
|
||||
Each field should have batch and channel dimensions followed by spatial
|
||||
dimensions.
|
||||
|
||||
Optionally the field can be transformed by lag2eul first.
|
||||
|
||||
See `map2map.models.power`.
|
||||
"""
|
||||
plt.close('all')
|
||||
|
||||
if label is not None:
|
||||
assert len(label) == len(fields)
|
||||
else:
|
||||
label = [None] * len(fields)
|
||||
|
||||
with torch.no_grad():
|
||||
if l2e:
|
||||
fields = lag2eul(*fields)
|
||||
|
||||
ks, Ps = [], []
|
||||
for field in fields:
|
||||
k, P, _ = power(field)
|
||||
ks.append(k)
|
||||
Ps.append(P)
|
||||
|
||||
ks = [k.cpu().numpy() for k in ks]
|
||||
Ps = [P.cpu().numpy() for P in Ps]
|
||||
|
||||
fig, axes = plt.subplots(figsize=(4.8, 3.6), dpi=150)
|
||||
|
||||
for k, P, l in zip(ks, Ps, label):
|
||||
axes.loglog(k, P, label=l, alpha=0.7)
|
||||
|
||||
axes.legend()
|
||||
axes.set_xlabel('unnormalized wavenumber')
|
||||
axes.set_ylabel('unnormalized power')
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
return fig
|
||||
|
|
|
@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .data import FieldDataset, DistFieldSampler
|
||||
from .data.figures import plt_slices
|
||||
from .data.figures import plt_slices, plt_power
|
||||
from . import models
|
||||
from .models import narrow_cast, resample, Lag2Eul
|
||||
from .utils import import_attr, load_model_state_dict
|
||||
|
@ -306,9 +306,18 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||
)
|
||||
logger.add_figure('fig/epoch/train', fig, global_step=epoch+1)
|
||||
logger.add_figure('fig/train', fig, global_step=epoch+1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||
# label=['in', 'out', 'tgt'])
|
||||
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
|
@ -358,9 +367,18 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||
)
|
||||
logger.add_figure('fig/epoch/val', fig, global_step=epoch+1)
|
||||
logger.add_figure('fig/val', fig, global_step=epoch+1)
|
||||
fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||
# label=['in', 'out', 'tgt'])
|
||||
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||
#fig.clf()
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue