diff --git a/map2map/data/figures.py b/map2map/data/figures.py index 3f8ee5f..6750d52 100644 --- a/map2map/data/figures.py +++ b/map2map/data/figures.py @@ -17,41 +17,51 @@ def fig3d(*fields, size=64, cmap=None, norm=None): nc = fields[-1].shape[0] nf = len(fields) - fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(5 * nf, 4.25 * nc)) + colorbar_frac = 0.15 / (0.85 * nc + 0.15) + fig, axes = plt.subplots(nc, nf, squeeze=False, figsize=(4 * nf, 4 * nc * (1 + colorbar_frac))) - if cmap is None: - if (fields[-1] >= 0).all(): - cmap = 'viridis' - elif (fields[-1] <= 0).all(): - raise NotImplementedError - else: - cmap = 'RdBu_r' + def quantize(x): + return 2 ** round(log2(x), ndigits=1) - if norm is None: - def quantize(x): - return 2 ** round(log2(x), ndigits=1) + for f in range(nf): + all_non_neg = (fields[f] >= 0).all() + all_non_pos = (fields[f] <= 0).all() - l2, l1, h1, h2 = np.percentile(fields[-1], [2.5, 16, 84, 97.5]) - w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2 - - if (fields[-1] >= 0).all(): - if h1 > 0.1 * h2: - norm = Normalize(vmin=0, vmax=quantize(2 * h2)) + if cmap is None: + if all_non_neg: + cmap_ = 'viridis' + elif all_non_pos: + raise NotImplementedError else: - norm = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2)) - elif (fields[-1] <= 0).all(): - raise NotImplementedError + cmap_ = 'RdBu_r' else: - if w1 > 0.1 * w2: - vlim = quantize(2.5 * w1) - norm = Normalize(vmin=-vlim, vmax=vlim) - else: - vlim = quantize(w2) - norm = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim) + cmap_ = cmap - for c in range(nc): - for f in range(nf): - axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap, norm=norm) - plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=axes) + if norm is None: + l2, l1, h1, h2 = np.percentile(fields[f], [2.5, 16, 84, 97.5]) + w1, w2 = (h1 - l1) / 2, (h2 - l2) / 2 + + if all_non_neg: + if h1 > 0.1 * h2: + norm_ = Normalize(vmin=0, vmax=quantize(2 * h2)) + else: + norm_ = LogNorm(vmin=quantize(0.5 * l2), vmax=quantize(2 * h2)) + elif all_non_pos: + raise NotImplementedError + else: + if w1 > 0.1 * w2: + vlim = quantize(2.5 * w1) + norm_ = Normalize(vmin=-vlim, vmax=vlim) + else: + vlim = quantize(w2) + norm_ = SymLogNorm(linthresh=0.1 * w1, vmin=-vlim, vmax=vlim) + else: + norm_ = norm + + for c in range(nc): + axes[c, f].imshow(fields[f][c, 0, :size, :size], cmap=cmap_, norm=norm_) + + plt.colorbar(ScalarMappable(norm=norm_, cmap=cmap_), ax=axes[:, f], + orientation='horizontal', fraction=colorbar_frac, pad=0.05) return fig diff --git a/map2map/train.py b/map2map/train.py index 5d70f11..2775911 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -83,10 +83,10 @@ def gpu_worker(local_rank, args): pin_memory=True ) - in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan + args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan model = getattr(models, args.model) - model = model(sum(in_chan) + args.noise_chan, sum(out_chan)) + model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan)) model.to(args.device) model = DistributedDataParallel(model, device_ids=[args.device], process_group=dist.new_group()) @@ -111,8 +111,8 @@ def gpu_worker(local_rank, args): if args.adv: adv_model = getattr(models, args.adv_model) adv_model = adv_model_wrapper(adv_model) - adv_model = adv_model(sum(in_chan + out_chan) - if args.cgan else sum(out_chan), 1) + adv_model = adv_model(sum(args.in_chan + args.out_chan) + if args.cgan else sum(args.out_chan), 1) adv_model.to(args.device) adv_model = DistributedDataParallel(adv_model, device_ids=[args.device], process_group=dist.new_group()) @@ -323,8 +323,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, 'real': epoch_loss[4], }, global_step=epoch+1) - skip_chan = sum(in_chan) if args.adv and args.cgan else 0 - args.logger.add_figure('fig/epoch/train', + skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0 + args.logger.add_figure('fig/epoch/train/in', + fig3d(narrow_like(input, output)[-1]), global_step =epoch+1) + args.logger.add_figure('fig/epoch/train/out', fig3d(output[-1, skip_chan:], target[-1, skip_chan:]), global_step =epoch+1) @@ -389,7 +391,9 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): 'real': epoch_loss[4], }, global_step=epoch+1) - skip_chan = sum(in_chan) if args.adv and args.cgan else 0 + skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0 + args.logger.add_figure('fig/epoch/val/in', + fig3d(narrow_like(input, output)[-1]), global_step =epoch+1) args.logger.add_figure('fig/epoch/val', fig3d(output[-1, skip_chan:], target[-1, skip_chan:]), global_step =epoch+1)