Add noise annealing duration

This commit is contained in:
Yin Li 2020-03-09 12:07:55 -04:00
parent 93d973b5c8
commit d01d0cee83
3 changed files with 6 additions and 3 deletions

View File

@ -87,6 +87,8 @@ def add_train_args(parser):
help='half-life (epoch) to anneal loss while enhancing adv-loss') help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--instance-noise', default=0, type=float, parser.add_argument('--instance-noise', default=0, type=float,
help='noise added to the adversary inputs to stabilize training') help='noise added to the adversary inputs to stabilize training')
parser.add_argument('--instance-noise-batches', default=1e4, type=float,
help='noise annealing duration')
parser.add_argument('--optimizer', default='Adam', type=str, parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim') help='optimizer from torch.optim')

View File

@ -4,11 +4,11 @@ import torch
class InstanceNoise: class InstanceNoise:
"""Instance noise, with a heuristic annealing schedule """Instance noise, with a heuristic annealing schedule
""" """
def __init__(self, init_std): def __init__(self, init_std, batches):
self.init_std = init_std self.init_std = init_std
self.anneal = 1 self.anneal = 1
self.ln2 = log(2) self.ln2 = log(2)
self.batches = 1e5 self.batches = batches
def std(self, adv_loss): def std(self, adv_loss):
self.anneal -= adv_loss / self.ln2 / self.batches self.anneal -= adv_loss / self.ln2 / self.batches

View File

@ -216,7 +216,8 @@ def gpu_worker(local_rank, node, args):
sys.stdout.flush() sys.stdout.flush()
if args.adv: if args.adv:
args.instance_noise = InstanceNoise(args.instance_noise) args.instance_noise = InstanceNoise(args.instance_noise,
args.instance_noise_batches)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
if not args.div_data: if not args.div_data: