Change instance noise

Use a simpler linear annealing schedule, as in the blog post;
draw independent noise samples for the output and the target
This commit is contained in:
Yin Li 2020-07-11 01:45:54 -04:00
parent 873258d8a7
commit 29ab550032
2 changed files with 10 additions and 15 deletions

View File

@@ -1,17 +1,14 @@
from math import log
import torch import torch
class InstanceNoise: class InstanceNoise:
"""Instance noise, with a heuristic annealing schedule """Instance noise, with a linear decaying schedule
""" """
def __init__(self, init_std, batches): def __init__(self, init_std, batches):
assert init_std >= 0, 'Noise std cannot be negative'
self.init_std = init_std self.init_std = init_std
self.anneal = 1 self._std = init_std
self.ln2 = log(2)
self.batches = batches self.batches = batches
def std(self, adv_loss): def std(self):
self.anneal -= adv_loss / self.ln2 / self.batches self._std -= self.init_std / self.batches
std = self.anneal * self.init_std return max(self._std, 0)
std = std if std > 0 else 0
return std

View File

@@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset, GroupedRandomSampler from .data import FieldDataset, GroupedRandomSampler
from .data.figures import plt_slices from .data.figures import plt_slices
from . import models from . import models
from .models import (narrow_cast, resample from .models import (narrow_cast, resample,
adv_model_wrapper, adv_criterion_wrapper, adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm, add_spectral_norm, rm_spectral_norm,
InstanceNoise) InstanceNoise)
@@ -347,13 +347,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start:
try: noise_std = args.instance_noise.std()
noise_std = args.instance_noise.std(adv_loss)
except NameError:
noise_std = args.instance_noise.std(0)
if noise_std > 0: if noise_std > 0:
noise = noise_std * torch.randn_like(output) noise = noise_std * torch.randn_like(output)
output = output + noise.detach() output = output + noise.detach()
noise = noise_std * torch.randn_like(target)
target = target + noise.detach() target = target + noise.detach()
del noise del noise
@@ -426,7 +424,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
'last': grads[-1], 'last': grads[-1],
}, global_step=batch) }, global_step=batch)
if args.adv and epoch >= args.adv_start: if args.adv and epoch >= args.adv_start and noise_std > 0:
logger.add_scalar('instance_noise', noise_std, logger.add_scalar('instance_noise', noise_std,
global_step=batch) global_step=batch)