Add noise annealing duration
This commit is contained in:
parent
93d973b5c8
commit
d01d0cee83
@ -87,6 +87,8 @@ def add_train_args(parser):
|
|||||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||||
parser.add_argument('--instance-noise', default=0, type=float,
|
parser.add_argument('--instance-noise', default=0, type=float,
|
||||||
help='noise added to the adversary inputs to stabilize training')
|
help='noise added to the adversary inputs to stabilize training')
|
||||||
|
parser.add_argument('--instance-noise-batches', default=1e4, type=float,
|
||||||
|
help='noise annealing duration')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
@ -4,11 +4,11 @@ import torch
|
|||||||
class InstanceNoise:
|
class InstanceNoise:
|
||||||
"""Instance noise, with a heuristic annealing schedule
|
"""Instance noise, with a heuristic annealing schedule
|
||||||
"""
|
"""
|
||||||
def __init__(self, init_std):
|
def __init__(self, init_std, batches):
|
||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
self.anneal = 1
|
self.anneal = 1
|
||||||
self.ln2 = log(2)
|
self.ln2 = log(2)
|
||||||
self.batches = 1e5
|
self.batches = batches
|
||||||
|
|
||||||
def std(self, adv_loss):
|
def std(self, adv_loss):
|
||||||
self.anneal -= adv_loss / self.ln2 / self.batches
|
self.anneal -= adv_loss / self.ln2 / self.batches
|
||||||
|
@ -216,7 +216,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
args.instance_noise = InstanceNoise(args.instance_noise)
|
args.instance_noise = InstanceNoise(args.instance_noise,
|
||||||
|
args.instance_noise_batches)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
|
Loading…
Reference in New Issue
Block a user