From 67e5ed9eb682a85e0f5c1c0a916703746019e4c9 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 7 May 2020 10:48:55 -0400 Subject: [PATCH] Remove smooth transition to adversarial loss --- map2map/args.py | 7 ++++++- map2map/train.py | 9 ++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index 0da57b6..bdecb0d 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -1,4 +1,5 @@ import argparse +import warnings from .train import ckpt_link @@ -91,7 +92,7 @@ def add_train_args(parser): 'e.g. 0.9 for label smoothing and 1 to disable') parser.add_argument('--loss-fraction', default=0.5, type=float, help='final fraction of loss (vs adv-loss)') - parser.add_argument('--loss-halflife', default=20, type=float, + parser.add_argument('--loss-halflife', default=20, type=deprecated, help='half-life (epoch) to anneal loss while enhancing adv-loss') parser.add_argument('--instance-noise', default=0, type=float, help='noise added to the adversary inputs to stabilize training') @@ -141,6 +142,10 @@ def str_list(s): return s.split(',') +def deprecated(s): + warnings.warn("deprecated argument", DeprecationWarning, stacklevel=2) + + #def int_tuple(t): # t = t.split(',') # t = tuple(int(i) for i in t) diff --git a/map2map/train.py b/map2map/train.py index ba23a72..6b7b244 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -348,11 +348,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, loss_adv, = adv_criterion(eval_out, real) epoch_loss[1] += loss_adv.item() - r = loss.item() / (loss_adv.item() + 1e-8) - f = args.loss_fraction - e = epoch - args.adv_start - d = 0.5 ** (e / args.loss_halflife) - loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv + ratio = loss.item() / (loss_adv.item() + 1e-8) + frac = args.loss_fraction + if epoch >= args.adv_start: + loss = frac * loss + (1 - frac) * ratio * loss_adv optimizer.zero_grad() loss.backward()