Add instance noise
This commit is contained in:
parent
1963530984
commit
ef993d83e9
@ -79,12 +79,14 @@ def add_train_args(parser):
|
||||
parser.add_argument('--adv-start', default=0, type=int,
|
||||
help='epoch to start adversarial training')
|
||||
parser.add_argument('--adv-label-smoothing', default=1, type=float,
|
||||
help='label of real samples for discriminator, '
|
||||
help='label of real samples for the adversary model, '
|
||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||
parser.add_argument('--instance-noise', default=0, type=float,
|
||||
help='noise added to the adversary inputs to stabilize training')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
|
@ -9,3 +9,4 @@ from .dice import DiceLoss, dice_loss
|
||||
|
||||
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||
from .spectral_norm import add_spectral_norm, rm_spectral_norm
|
||||
from .instance_noise import InstanceNoise
|
||||
|
18
map2map/models/instance_noise.py
Normal file
18
map2map/models/instance_noise.py
Normal file
@ -0,0 +1,18 @@
|
||||
from math import log
|
||||
import torch
|
||||
|
||||
class InstanceNoise:
    """Instance noise with a heuristic annealing schedule.

    Noise of standard deviation ``init_std`` is added to the adversary
    inputs to stabilize GAN training.  The noise is annealed linearly to
    zero as the cumulative adversarial loss approaches
    ``batches * ln(2)`` — ln 2 being the loss of an ideal,
    maximally-confused discriminator — so the noise fades out once the
    adversary has been near equilibrium for about ``batches`` batches.

    Parameters
    ----------
    init_std : float
        Initial noise standard deviation.
    batches : float, optional
        Annealing horizon in batches (default ``1e5``).
    """

    def __init__(self, init_std, batches=1e5):
        self.init_std = init_std
        # running sum of adversarial losses seen so far
        self.adv_loss_cum = 0
        self.ln2 = log(2)
        self.batches = batches

    def std(self, adv_loss):
        """Accumulate ``adv_loss`` and return the annealed noise std.

        Returns ``init_std`` scaled by the remaining annealing fraction,
        clamped at zero once the cumulative loss exceeds the horizon.
        """
        self.adv_loss_cum += adv_loss
        # fraction of the annealing budget consumed so far
        anneal = 1 - self.adv_loss_cum / (self.ln2 * self.batches)
        return max(0, anneal) * self.init_std
|
@ -19,7 +19,8 @@ from .data.figures import fig3d
|
||||
from . import models
|
||||
from .models import (narrow_like,
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
add_spectral_norm, rm_spectral_norm)
|
||||
add_spectral_norm, rm_spectral_norm,
|
||||
InstanceNoise)
|
||||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
@ -210,6 +211,9 @@ def gpu_worker(local_rank, node, args):
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
if args.adv:
|
||||
args.instance_noise = InstanceNoise(args.instance_noise)
|
||||
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
if not args.div_data:
|
||||
train_sampler.set_epoch(epoch)
|
||||
@ -297,6 +301,16 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
try:
|
||||
noise_std = args.instance_noise.std(adv_loss)
|
||||
except NameError:
|
||||
noise_std = args.instance_noise.std(0)
|
||||
if noise_std > 0:
|
||||
noise = noise_std * torch.randn_like(output)
|
||||
output = output + noise.detach()
|
||||
target = target + noise.detach()
|
||||
del noise
|
||||
|
||||
if args.cgan:
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
@ -367,6 +381,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
'last': grads[1],
|
||||
}, global_step=batch)
|
||||
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
logger.add_scalar('instance_noise', noise_std,
|
||||
global_step=batch)
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
if rank == 0:
|
||||
|
Loading…
Reference in New Issue
Block a user