Change instance noise

Simpler linear annealing as in the blog post;
different noise to output and target
This commit is contained in:
Yin Li 2020-07-11 01:45:54 -04:00
parent 873258d8a7
commit 29ab550032
2 changed files with 10 additions and 15 deletions

View File

@ -1,17 +1,14 @@
from math import log
import torch
class InstanceNoise:
"""Instance noise, with a heuristic annealing schedule
"""Instance noise, with a linear decaying schedule
"""
def __init__(self, init_std, batches):
assert init_std >= 0, 'Noise std cannot be negative'
self.init_std = init_std
self.anneal = 1
self.ln2 = log(2)
self._std = init_std
self.batches = batches
def std(self, adv_loss):
self.anneal -= adv_loss / self.ln2 / self.batches
std = self.anneal * self.init_std
std = std if std > 0 else 0
return std
def std(self):
self._std -= self.init_std / self.batches
return max(self._std, 0)

View File

@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler
from .data.figures import plt_slices
from . import models
from .models import (narrow_cast, resample
from .models import (narrow_cast, resample,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
@ -347,13 +347,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start:
try:
noise_std = args.instance_noise.std(adv_loss)
except NameError:
noise_std = args.instance_noise.std(0)
noise_std = args.instance_noise.std()
if noise_std > 0:
noise = noise_std * torch.randn_like(output)
output = output + noise.detach()
noise = noise_std * torch.randn_like(target)
target = target + noise.detach()
del noise
@ -426,7 +424,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
'last': grads[-1],
}, global_step=batch)
if args.adv and epoch >= args.adv_start:
if args.adv and epoch >= args.adv_start and noise_std > 0:
logger.add_scalar('instance_noise', noise_std,
global_step=batch)