Add noise annealing duration
This commit is contained in:
parent
93d973b5c8
commit
d01d0cee83
@ -87,6 +87,8 @@ def add_train_args(parser):
|
||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||
parser.add_argument('--instance-noise', default=0, type=float,
|
||||
help='noise added to the adversary inputs to stabilize training')
|
||||
parser.add_argument('--instance-noise-batches', default=1e4, type=float,
|
||||
help='noise annealing duration')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
|
@ -4,11 +4,11 @@ import torch
|
||||
class InstanceNoise:
|
||||
"""Instance noise, with a heuristic annealing schedule
|
||||
"""
|
||||
def __init__(self, init_std):
|
||||
def __init__(self, init_std, batches):
|
||||
self.init_std = init_std
|
||||
self.anneal = 1
|
||||
self.ln2 = log(2)
|
||||
self.batches = 1e5
|
||||
self.batches = batches
|
||||
|
||||
def std(self, adv_loss):
|
||||
self.anneal -= adv_loss / self.ln2 / self.batches
|
||||
|
@ -216,7 +216,8 @@ def gpu_worker(local_rank, node, args):
|
||||
sys.stdout.flush()
|
||||
|
||||
if args.adv:
|
||||
args.instance_noise = InstanceNoise(args.instance_noise)
|
||||
args.instance_noise = InstanceNoise(args.instance_noise,
|
||||
args.instance_noise_batches)
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
if not args.div_data:
|
||||
|
Loading…
Reference in New Issue
Block a user