Remove smooth transition to adversarial loss

This commit is contained in:
Yin Li 2020-05-07 10:48:55 -04:00
parent f442dd59ba
commit 67e5ed9eb6
2 changed files with 10 additions and 6 deletions

View File

@@ -1,4 +1,5 @@
import argparse
import warnings
from .train import ckpt_link
@@ -91,7 +92,7 @@ def add_train_args(parser):
'e.g. 0.9 for label smoothing and 1 to disable')
parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=20, type=float, parser.add_argument('--loss-halflife', default=20, type=deprecated,
help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--instance-noise', default=0, type=float,
help='noise added to the adversary inputs to stabilize training')
@@ -141,6 +142,10 @@ def str_list(s):
return s.split(',')
def deprecated(s):
warnings.warn("deprecated argument", DeprecationWarning, stacklevel=2)
#def int_tuple(t):
#    t = t.split(',')
#    t = tuple(int(i) for i in t)

View File

@@ -348,11 +348,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
loss_adv, = adv_criterion(eval_out, real)
epoch_loss[1] += loss_adv.item()
r = loss.item() / (loss_adv.item() + 1e-8) ratio = loss.item() / (loss_adv.item() + 1e-8)
f = args.loss_fraction frac = args.loss_fraction
e = epoch - args.adv_start if epoch >= args.adv_start:
d = 0.5 ** (e / args.loss_halflife) loss = frac * loss + (1 - frac) * ratio * loss_adv
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
optimizer.zero_grad()
loss.backward()