Add instance noise

This commit is contained in:
Yin Li 2020-03-03 14:59:15 -05:00
parent 1963530984
commit ef993d83e9
4 changed files with 41 additions and 2 deletions

View File

@ -79,12 +79,14 @@ def add_train_args(parser):
parser.add_argument('--adv-start', default=0, type=int, parser.add_argument('--adv-start', default=0, type=int,
help='epoch to start adversarial training') help='epoch to start adversarial training')
parser.add_argument('--adv-label-smoothing', default=1, type=float, parser.add_argument('--adv-label-smoothing', default=1, type=float,
help='label of real samples for discriminator, ' help='label of real samples for the adversary model, '
'e.g. 0.9 for label smoothing and 1 to disable') 'e.g. 0.9 for label smoothing and 1 to disable')
parser.add_argument('--loss-fraction', default=0.5, type=float, parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)') help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=20, type=float, parser.add_argument('--loss-halflife', default=20, type=float,
help='half-life (epoch) to anneal loss while enhancing adv-loss') help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--instance-noise', default=0, type=float,
help='noise added to the adversary inputs to stabilize training')
parser.add_argument('--optimizer', default='Adam', type=str, parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim') help='optimizer from torch.optim')

View File

@ -9,3 +9,4 @@ from .dice import DiceLoss, dice_loss
from .adversary import adv_model_wrapper, adv_criterion_wrapper from .adversary import adv_model_wrapper, adv_criterion_wrapper
from .spectral_norm import add_spectral_norm, rm_spectral_norm from .spectral_norm import add_spectral_norm, rm_spectral_norm
from .instance_noise import InstanceNoise

View File

@ -0,0 +1,18 @@
from math import log
import torch
class InstanceNoise:
    """Instance noise for adversarial training, with a heuristic
    annealing schedule.

    The noise standard deviation starts at ``init_std`` and decays
    linearly to zero as the running sum of adversary losses approaches
    ``ln(2) * batches`` — i.e. roughly ``batches`` batches at the
    chance-level adversary loss of ``ln 2``.
    """

    def __init__(self, init_std):
        # starting noise standard deviation
        self.init_std = init_std
        # running sum of adversary losses observed so far
        self.adv_loss_cum = 0
        # chance-level adversary loss (binary cross entropy at p = 0.5)
        self.ln2 = log(2)
        # heuristic horizon: anneal to zero over about this many batches
        self.batches = 1e5

    def std(self, adv_loss):
        """Accumulate ``adv_loss`` and return the annealed noise std.

        Returns ``init_std`` scaled by a factor that falls linearly
        from 1 to 0 (clipped at 0) with the cumulative adversary loss.
        """
        self.adv_loss_cum += adv_loss
        decay = 1 - self.adv_loss_cum / self.ln2 / self.batches
        return max(decay, 0) * self.init_std

View File

@ -19,7 +19,8 @@ from .data.figures import fig3d
from . import models from . import models
from .models import (narrow_like, from .models import (narrow_like,
adv_model_wrapper, adv_criterion_wrapper, adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm) add_spectral_norm, rm_spectral_norm,
InstanceNoise)
from .state import load_model_state_dict from .state import load_model_state_dict
@ -210,6 +211,9 @@ def gpu_worker(local_rank, node, args):
pprint(vars(args)) pprint(vars(args))
sys.stdout.flush() sys.stdout.flush()
if args.adv:
args.instance_noise = InstanceNoise(args.instance_noise)
for epoch in range(start_epoch, args.epochs): for epoch in range(start_epoch, args.epochs):
if not args.div_data: if not args.div_data:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
@ -297,6 +301,16 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start:
try:
noise_std = args.instance_noise.std(adv_loss)
except NameError:
noise_std = args.instance_noise.std(0)
if noise_std > 0:
noise = noise_std * torch.randn_like(output)
output = output + noise.detach()
target = target + noise.detach()
del noise
if args.cgan: if args.cgan:
output = torch.cat([input, output], dim=1) output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
@ -367,6 +381,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
'last': grads[1], 'last': grads[1],
}, global_step=batch) }, global_step=batch)
if args.adv and epoch >= args.adv_start:
logger.add_scalar('instance_noise', noise_std,
global_step=batch)
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)
epoch_loss /= len(loader) * world_size epoch_loss /= len(loader) * world_size
if rank == 0: if rank == 0: