Change instance noise
Simpler linear annealing as in the blog post; different noise to output and target
This commit is contained in:
parent
873258d8a7
commit
29ab550032
@ -1,17 +1,14 @@
|
||||
from math import log
|
||||
import torch
|
||||
|
||||
class InstanceNoise:
|
||||
"""Instance noise, with a heuristic annealing schedule
|
||||
"""Instance noise, with a linear decaying schedule
|
||||
"""
|
||||
def __init__(self, init_std, batches):
|
||||
assert init_std >= 0, 'Noise std cannot be negative'
|
||||
self.init_std = init_std
|
||||
self.anneal = 1
|
||||
self.ln2 = log(2)
|
||||
self._std = init_std
|
||||
self.batches = batches
|
||||
|
||||
def std(self, adv_loss):
|
||||
self.anneal -= adv_loss / self.ln2 / self.batches
|
||||
std = self.anneal * self.init_std
|
||||
std = std if std > 0 else 0
|
||||
return std
|
||||
def std(self):
|
||||
self._std -= self.init_std / self.batches
|
||||
return max(self._std, 0)
|
||||
|
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from .data import FieldDataset, GroupedRandomSampler
|
||||
from .data.figures import plt_slices
|
||||
from . import models
|
||||
from .models import (narrow_cast, resample
|
||||
from .models import (narrow_cast, resample,
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
add_spectral_norm, rm_spectral_norm,
|
||||
InstanceNoise)
|
||||
@ -347,13 +347,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
try:
|
||||
noise_std = args.instance_noise.std(adv_loss)
|
||||
except NameError:
|
||||
noise_std = args.instance_noise.std(0)
|
||||
noise_std = args.instance_noise.std()
|
||||
if noise_std > 0:
|
||||
noise = noise_std * torch.randn_like(output)
|
||||
output = output + noise.detach()
|
||||
noise = noise_std * torch.randn_like(target)
|
||||
target = target + noise.detach()
|
||||
del noise
|
||||
|
||||
@ -426,7 +424,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
'last': grads[-1],
|
||||
}, global_step=batch)
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
||||
logger.add_scalar('instance_noise', noise_std,
|
||||
global_step=batch)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user