Add figures with tensorboard
This commit is contained in:
parent
7f6578c63e
commit
db69e9f953
57
map2map/data/figures.py
Normal file
57
map2map/data/figures.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from math import log2, log10, ceil
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.colors import Normalize, LogNorm, SymLogNorm
|
||||||
|
from matplotlib.cm import ScalarMappable
|
||||||
|
|
||||||
|
|
||||||
|
def fig3d(*fields, size=64, cmap=None, norm=None):
|
||||||
|
fields = [f.detach().cpu().numpy() if isinstance(f, torch.Tensor) else f
|
||||||
|
for f in fields]
|
||||||
|
|
||||||
|
assert all(isinstance(f, np.ndarray) for f in fields)
|
||||||
|
|
||||||
|
nc = fields[-1].shape[0]
|
||||||
|
nf = len(fields)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(5 * nf, 4.25 * nc))
|
||||||
|
|
||||||
|
if cmap is None:
|
||||||
|
if (fields[-1] >= 0).all():
|
||||||
|
cmap = 'viridis'
|
||||||
|
elif (fields[-1] <= 0).all():
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
cmap = 'RdBu_r'
|
||||||
|
|
||||||
|
if norm is None:
|
||||||
|
def quantize(x):
|
||||||
|
return 2 ** round(log2(x), ndigits=1)
|
||||||
|
|
||||||
|
l2, l1, h1, h2 = np.percentile(fields[-1], [2.5, 16, 84, 97.5])
|
||||||
|
w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
|
||||||
|
|
||||||
|
if (fields[-1] >= 0).all():
|
||||||
|
if h1 > 0.1 * h2:
|
||||||
|
norm = Normalize(vmin=0, vmax=quantize(2 * h2))
|
||||||
|
else:
|
||||||
|
norm = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
|
||||||
|
elif (fields[-1] <= 0).all():
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
if w1 > 0.1 * w2:
|
||||||
|
vlim = quantize(2.5 * w1)
|
||||||
|
norm = Normalize(vmin=-vlim, vmax=vlim)
|
||||||
|
else:
|
||||||
|
vlim = quantize(w2)
|
||||||
|
norm = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim)
|
||||||
|
|
||||||
|
for c in range(nc):
|
||||||
|
for f in range(nf):
|
||||||
|
axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap, norm=norm)
|
||||||
|
plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=axes)
|
||||||
|
|
||||||
|
return fig
|
@ -10,6 +10,7 @@ from torch.utils.data import DataLoader
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
|
from .data.figures import fig3d
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_like
|
from .models import narrow_like
|
||||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||||
@ -322,6 +323,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
|
skip_chan = sum(in_chan) if args.adv and args.cgan else 0
|
||||||
|
args.logger.add_figure('fig/epoch/train',
|
||||||
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:]),
|
||||||
|
global_step =epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
@ -383,4 +389,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
|
skip_chan = sum(in_chan) if args.adv and args.cgan else 0
|
||||||
|
args.logger.add_figure('fig/epoch/val',
|
||||||
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:]),
|
||||||
|
global_step =epoch+1)
|
||||||
|
|
||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
Loading…
Reference in New Issue
Block a user