Add per field colormap and Fix scope bug

Yin Li 2020-02-05 11:58:05 -05:00
parent a22fb64d12
commit b609797e27
2 changed files with 51 additions and 37 deletions


@@ -17,41 +17,51 @@ def fig3d(*fields, size=64, cmap=None, norm=None):
     nc = fields[-1].shape[0]
     nf = len(fields)
 
-    fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(5 * nf, 4.25 * nc))
+    colorbar_frac = 0.15 / (0.85 * nc + 0.15)
+    fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(4 * nf, 4 * nc * (1 + colorbar_frac)))
 
-    if cmap is None:
-        if (fields[-1] >= 0).all():
-            cmap = 'viridis'
-        elif (fields[-1] <= 0).all():
-            raise NotImplementedError
-        else:
-            cmap = 'RdBu_r'
+    def quantize(x):
+        return 2 ** round(log2(x), ndigits=1)
 
-    if norm is None:
-        def quantize(x):
-            return 2 ** round(log2(x), ndigits=1)
+    for f in range(nf):
+        all_non_neg = (fields[f] >= 0).all()
+        all_non_pos = (fields[f] <= 0).all()
 
-        l2, l1, h1, h2 = np.percentile(fields[-1], [2.5, 16, 84, 97.5])
-        w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
-
-        if (fields[-1] >= 0).all():
-            if h1 > 0.1 * h2:
-                norm = Normalize(vmin=0, vmax=quantize(2 * h2))
-            else:
-                norm = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
-        elif (fields[-1] <= 0).all():
-            raise NotImplementedError
-        else:
-            if w1 > 0.1 * w2:
-                vlim = quantize(2.5 * w1)
-                norm = Normalize(vmin=-vlim, vmax=vlim)
-            else:
-                vlim = quantize(w2)
-                norm = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim)
+        if cmap is None:
+            if all_non_neg:
+                cmap_ = 'viridis'
+            elif all_non_pos:
+                raise NotImplementedError
+            else:
+                cmap_ = 'RdBu_r'
+        else:
+            cmap_ = cmap
 
-    for c in range(nc):
-        for f in range(nf):
-            axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap, norm=norm)
+        if norm is None:
+            l2, l1, h1, h2 = np.percentile(fields[f], [2.5, 16, 84, 97.5])
+            w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2
 
-    plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=axes)
+            if all_non_neg:
+                if h1 > 0.1 * h2:
+                    norm_ = Normalize(vmin=0, vmax=quantize(2 * h2))
+                else:
+                    norm_ = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2))
+            elif all_non_pos:
+                raise NotImplementedError
+            else:
+                if w1 > 0.1 * w2:
+                    vlim = quantize(2.5 * w1)
+                    norm_ = Normalize(vmin=-vlim, vmax=vlim)
+                else:
+                    vlim = quantize(w2)
+                    norm_ = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim)
+        else:
+            norm_ = norm
+
+        for c in range(nc):
+            axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap_, norm=norm_)
+
+        plt.colorbar(ScalarMappable(norm=norm_, cmap=cmap_), ax=axes[:, f],
+                     orientation='horizontal', fraction=colorbar_frac, pad=0.05)
 
     return fig
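
The hunk above moves the colormap and normalization choice inside a loop over the fields, so each column of panels now gets its own cmap_/norm_ and its own horizontal colorbar sized by colorbar_frac, instead of everything being derived from the last field. A minimal usage sketch, not part of the commit: the field names and random data below are made up, and the usual NumPy import plus this module's fig3d are assumed.

    import numpy as np

    # one non-negative field (picks 'viridis') and one signed field (picks 'RdBu_r'),
    # each shaped (channel, depth, height, width) to match fields[f][c, 0, :size, :size]
    density = np.random.lognormal(size=(1, 1, 64, 64))
    velocity = np.random.randn(1, 1, 64, 64)

    fig = fig3d(density, velocity, size=64)   # per-field cmap and norm, one colorbar per column
    fig.savefig('fields.png')

The remaining hunks below are in the training script, touching gpu_worker, train, and validate.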


@@ -83,10 +83,10 @@ def gpu_worker(local_rank, args):
         pin_memory=True
     )
 
-    in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
+    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
+    model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
     model.to(args.device)
     model = DistributedDataParallel(model, device_ids=[args.device],
             process_group=dist.new_group())
@@ -111,8 +111,8 @@ def gpu_worker(local_rank, args):
     if args.adv:
         adv_model = getattr(models, args.adv_model)
         adv_model = adv_model_wrapper(adv_model)
-        adv_model = adv_model(sum(in_chan + out_chan)
-                if args.cgan else sum(out_chan), 1)
+        adv_model = adv_model(sum(args.in_chan + args.out_chan)
+                if args.cgan else sum(args.out_chan), 1)
         adv_model.to(args.device)
         adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
                 process_group=dist.new_group())
@@ -323,8 +323,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
            'real': epoch_loss[4],
        }, global_step=epoch+1)
 
-       skip_chan = sum(in_chan) if args.adv and args.cgan else 0
-       args.logger.add_figure('fig/epoch/train',
+       skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
+       args.logger.add_figure('fig/epoch/train/in',
+           fig3d(narrow_like(input, output)[-1]), global_step=epoch+1)
+       args.logger.add_figure('fig/epoch/train/out',
            fig3d(output[-1, skip_chan:], target[-1, skip_chan:]),
            global_step=epoch+1)
 
@@ -389,7 +391,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
            'real': epoch_loss[4],
        }, global_step=epoch+1)
 
-       skip_chan = sum(in_chan) if args.adv and args.cgan else 0
+       skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
+       args.logger.add_figure('fig/epoch/val/in',
+           fig3d(narrow_like(input, output)[-1]), global_step=epoch+1)
        args.logger.add_figure('fig/epoch/val',
            fig3d(output[-1, skip_chan:], target[-1, skip_chan:]),
            global_step=epoch+1)
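
Judging from the last two hunks, the "scope bug" in the commit title is that train() and validate() computed skip_chan from in_chan, a name that only existed as a local inside gpu_worker(), so the logging branch would raise a NameError once args.adv and args.cgan were enabled; stashing the channel counts on args makes them reachable wherever args is passed. A stripped-down sketch of the pattern, for illustration only; the helper names and channel lists here are hypothetical:

    from argparse import Namespace

    def setup(args):
        # values discovered during dataset setup are attached to the shared namespace
        args.in_chan, args.out_chan = [1], [3]

    def train(args):
        # visible here because it travels with args, not as a local of setup()
        skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
        return skip_chan

    args = Namespace(adv=True, cgan=True)
    setup(args)
    print(train(args))   # -> 1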