Remove smooth transition to adversarial loss
This commit is contained in:
parent
f442dd59ba
commit
67e5ed9eb6
@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
from .train import ckpt_link
|
||||
|
||||
@ -91,7 +92,7 @@ def add_train_args(parser):
|
||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||
parser.add_argument('--loss-halflife', default=20, type=deprecated,
|
||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||
parser.add_argument('--instance-noise', default=0, type=float,
|
||||
help='noise added to the adversary inputs to stabilize training')
|
||||
@ -141,6 +142,10 @@ def str_list(s):
|
||||
return s.split(',')
|
||||
|
||||
|
||||
def deprecated(s):
|
||||
warnings.warn("deprecated argument", DeprecationWarning, stacklevel=2)
|
||||
|
||||
|
||||
#def int_tuple(t):
|
||||
# t = t.split(',')
|
||||
# t = tuple(int(i) for i in t)
|
||||
|
@ -348,11 +348,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
r = loss.item() / (loss_adv.item() + 1e-8)
|
||||
f = args.loss_fraction
|
||||
e = epoch - args.adv_start
|
||||
d = 0.5 ** (e / args.loss_halflife)
|
||||
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
||||
ratio = loss.item() / (loss_adv.item() + 1e-8)
|
||||
frac = args.loss_fraction
|
||||
if epoch >= args.adv_start:
|
||||
loss = frac * loss + (1 - frac) * ratio * loss_adv
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
Loading…
Reference in New Issue
Block a user