Add power spectrum tracking
This commit is contained in:
parent
afaf4675fe
commit
3eb1b0bccc
@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
|
|||||||
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
||||||
from matplotlib.cm import ScalarMappable
|
from matplotlib.cm import ScalarMappable
|
||||||
|
|
||||||
|
from ..models import lag2eul, power
|
||||||
|
|
||||||
|
|
||||||
def quantize(x):
|
def quantize(x):
|
||||||
return 2 ** round(log2(x), ndigits=1)
|
return 2 ** round(log2(x), ndigits=1)
|
||||||
@ -15,14 +17,15 @@ def quantize(x):
|
|||||||
|
|
||||||
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
||||||
"""Plot slices of fields of more than 2 spatial dimensions.
|
"""Plot slices of fields of more than 2 spatial dimensions.
|
||||||
|
|
||||||
|
Each field should have a channel dimension followed by spatial dimensions,
|
||||||
|
i.e. no batch dimension.
|
||||||
"""
|
"""
|
||||||
plt.close('all')
|
plt.close('all')
|
||||||
|
|
||||||
fields = [field.detach().cpu().numpy() if isinstance(field, torch.Tensor)
|
assert all(isinstance(field, torch.Tensor) for field in fields)
|
||||||
else field for field in fields]
|
|
||||||
|
|
||||||
assert all(isinstance(field, np.ndarray) for field in fields)
|
fields = [field.detach().cpu().numpy() for field in fields]
|
||||||
assert all(field.ndim == fields[0].ndim for field in fields)
|
|
||||||
|
|
||||||
nc = max(field.shape[0] for field in fields)
|
nc = max(field.shape[0] for field in fields)
|
||||||
nf = len(fields)
|
nf = len(fields)
|
||||||
@ -110,3 +113,47 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None):
|
|||||||
fig.tight_layout()
|
fig.tight_layout()
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plt_power(*fields, l2e=False, label=None):
|
||||||
|
"""Plot power spectra of fields.
|
||||||
|
|
||||||
|
Each field should have batch and channel dimensions followed by spatial
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
Optionally the field can be transformed by lag2eul first.
|
||||||
|
|
||||||
|
See `map2map.models.power`.
|
||||||
|
"""
|
||||||
|
plt.close('all')
|
||||||
|
|
||||||
|
if label is not None:
|
||||||
|
assert len(label) == len(fields)
|
||||||
|
else:
|
||||||
|
label = [None] * len(fields)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if l2e:
|
||||||
|
fields = lag2eul(*fields)
|
||||||
|
|
||||||
|
ks, Ps = [], []
|
||||||
|
for field in fields:
|
||||||
|
k, P, _ = power(field)
|
||||||
|
ks.append(k)
|
||||||
|
Ps.append(P)
|
||||||
|
|
||||||
|
ks = [k.cpu().numpy() for k in ks]
|
||||||
|
Ps = [P.cpu().numpy() for P in Ps]
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(figsize=(4.8, 3.6), dpi=150)
|
||||||
|
|
||||||
|
for k, P, l in zip(ks, Ps, label):
|
||||||
|
axes.loglog(k, P, label=l, alpha=0.7)
|
||||||
|
|
||||||
|
axes.legend()
|
||||||
|
axes.set_xlabel('unnormalized wavenumber')
|
||||||
|
axes.set_ylabel('unnormalized power')
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
@ -14,7 +14,7 @@ from torch.utils.data import DataLoader
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from .data import FieldDataset, DistFieldSampler
|
from .data import FieldDataset, DistFieldSampler
|
||||||
from .data.figures import plt_slices
|
from .data.figures import plt_slices, plt_power
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_cast, resample, Lag2Eul
|
from .models import narrow_cast, resample, Lag2Eul
|
||||||
from .utils import import_attr, load_model_state_dict
|
from .utils import import_attr, load_model_state_dict
|
||||||
@ -306,9 +306,18 @@ def train(epoch, loader, model, lag2eul, criterion,
|
|||||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/epoch/train', fig, global_step=epoch+1)
|
logger.add_figure('fig/train', fig, global_step=epoch+1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
|
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||||
|
#logger.add_figure('fig/train/power/lag', fig, global_step=epoch+1)
|
||||||
|
#fig.clf()
|
||||||
|
|
||||||
|
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||||
|
# label=['in', 'out', 'tgt'])
|
||||||
|
#logger.add_figure('fig/train/power/eul', fig, global_step=epoch+1)
|
||||||
|
#fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
@ -358,9 +367,18 @@ def validate(epoch, loader, model, lag2eul, criterion, logger, device, args):
|
|||||||
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
title=['in', 'lag_out', 'lag_tgt', 'lag_out - lag_tgt',
|
||||||
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
'eul_out', 'eul_tgt', 'eul_out - eul_tgt'],
|
||||||
)
|
)
|
||||||
logger.add_figure('fig/epoch/val', fig, global_step=epoch+1)
|
logger.add_figure('fig/val', fig, global_step=epoch+1)
|
||||||
fig.clf()
|
fig.clf()
|
||||||
|
|
||||||
|
#fig = plt_power(input, lag_out, lag_tgt, label=['in', 'out', 'tgt'])
|
||||||
|
#logger.add_figure('fig/val/power/lag', fig, global_step=epoch+1)
|
||||||
|
#fig.clf()
|
||||||
|
|
||||||
|
#fig = plt_power(input, lag_out, lag_tgt, l2e=True,
|
||||||
|
# label=['in', 'out', 'tgt'])
|
||||||
|
#logger.add_figure('fig/val/power/eul', fig, global_step=epoch+1)
|
||||||
|
#fig.clf()
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user