Change instance noise
Simpler linear annealing as in the blog post; different noise to output and target
This commit is contained in:
parent
873258d8a7
commit
29ab550032
@ -1,17 +1,14 @@
|
|||||||
from math import log
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
class InstanceNoise:
|
class InstanceNoise:
|
||||||
"""Instance noise, with a heuristic annealing schedule
|
"""Instance noise, with a linear decaying schedule
|
||||||
"""
|
"""
|
||||||
def __init__(self, init_std, batches):
|
def __init__(self, init_std, batches):
|
||||||
|
assert init_std >= 0, 'Noise std cannot be negative'
|
||||||
self.init_std = init_std
|
self.init_std = init_std
|
||||||
self.anneal = 1
|
self._std = init_std
|
||||||
self.ln2 = log(2)
|
|
||||||
self.batches = batches
|
self.batches = batches
|
||||||
|
|
||||||
def std(self, adv_loss):
|
def std(self):
|
||||||
self.anneal -= adv_loss / self.ln2 / self.batches
|
self._std -= self.init_std / self.batches
|
||||||
std = self.anneal * self.init_std
|
return max(self._std, 0)
|
||||||
std = std if std > 0 else 0
|
|
||||||
return std
|
|
||||||
|
@ -17,7 +17,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset, GroupedRandomSampler
|
from .data import FieldDataset, GroupedRandomSampler
|
||||||
from .data.figures import plt_slices
|
from .data.figures import plt_slices
|
||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_cast, resample
|
from .models import (narrow_cast, resample,
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
add_spectral_norm, rm_spectral_norm,
|
add_spectral_norm, rm_spectral_norm,
|
||||||
InstanceNoise)
|
InstanceNoise)
|
||||||
@ -347,13 +347,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv and epoch >= args.adv_start:
|
if args.adv and epoch >= args.adv_start:
|
||||||
try:
|
noise_std = args.instance_noise.std()
|
||||||
noise_std = args.instance_noise.std(adv_loss)
|
|
||||||
except NameError:
|
|
||||||
noise_std = args.instance_noise.std(0)
|
|
||||||
if noise_std > 0:
|
if noise_std > 0:
|
||||||
noise = noise_std * torch.randn_like(output)
|
noise = noise_std * torch.randn_like(output)
|
||||||
output = output + noise.detach()
|
output = output + noise.detach()
|
||||||
|
noise = noise_std * torch.randn_like(target)
|
||||||
target = target + noise.detach()
|
target = target + noise.detach()
|
||||||
del noise
|
del noise
|
||||||
|
|
||||||
@ -426,7 +424,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
'last': grads[-1],
|
'last': grads[-1],
|
||||||
}, global_step=batch)
|
}, global_step=batch)
|
||||||
|
|
||||||
if args.adv and epoch >= args.adv_start:
|
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
||||||
logger.add_scalar('instance_noise', noise_std,
|
logger.add_scalar('instance_noise', noise_std,
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user