Remove smooth transition to adversarial loss
This commit is contained in:
parent
f442dd59ba
commit
67e5ed9eb6
@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import warnings
|
||||||
|
|
||||||
from .train import ckpt_link
|
from .train import ckpt_link
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ def add_train_args(parser):
|
|||||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='final fraction of loss (vs adv-loss)')
|
help='final fraction of loss (vs adv-loss)')
|
||||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
parser.add_argument('--loss-halflife', default=20, type=deprecated,
|
||||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||||
parser.add_argument('--instance-noise', default=0, type=float,
|
parser.add_argument('--instance-noise', default=0, type=float,
|
||||||
help='noise added to the adversary inputs to stabilize training')
|
help='noise added to the adversary inputs to stabilize training')
|
||||||
@ -141,6 +142,10 @@ def str_list(s):
|
|||||||
return s.split(',')
|
return s.split(',')
|
||||||
|
|
||||||
|
|
||||||
|
def deprecated(s):
|
||||||
|
warnings.warn("deprecated argument", DeprecationWarning, stacklevel=2)
|
||||||
|
|
||||||
|
|
||||||
#def int_tuple(t):
|
#def int_tuple(t):
|
||||||
# t = t.split(',')
|
# t = t.split(',')
|
||||||
# t = tuple(int(i) for i in t)
|
# t = tuple(int(i) for i in t)
|
||||||
|
@ -348,11 +348,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
loss_adv, = adv_criterion(eval_out, real)
|
loss_adv, = adv_criterion(eval_out, real)
|
||||||
epoch_loss[1] += loss_adv.item()
|
epoch_loss[1] += loss_adv.item()
|
||||||
|
|
||||||
r = loss.item() / (loss_adv.item() + 1e-8)
|
ratio = loss.item() / (loss_adv.item() + 1e-8)
|
||||||
f = args.loss_fraction
|
frac = args.loss_fraction
|
||||||
e = epoch - args.adv_start
|
if epoch >= args.adv_start:
|
||||||
d = 0.5 ** (e / args.loss_halflife)
|
loss = frac * loss + (1 - frac) * ratio * loss_adv
|
||||||
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
Loading…
Reference in New Issue
Block a user