Add instance noise
This commit is contained in:
parent
1963530984
commit
ef993d83e9
@ -79,12 +79,14 @@ def add_train_args(parser):
|
|||||||
parser.add_argument('--adv-start', default=0, type=int,
|
parser.add_argument('--adv-start', default=0, type=int,
|
||||||
help='epoch to start adversarial training')
|
help='epoch to start adversarial training')
|
||||||
parser.add_argument('--adv-label-smoothing', default=1, type=float,
|
parser.add_argument('--adv-label-smoothing', default=1, type=float,
|
||||||
help='label of real samples for discriminator, '
|
help='label of real samples for the adversary model, '
|
||||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='final fraction of loss (vs adv-loss)')
|
help='final fraction of loss (vs adv-loss)')
|
||||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||||
|
parser.add_argument('--instance-noise', default=0, type=float,
|
||||||
|
help='noise added to the adversary inputs to stabilize training')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
@ -9,3 +9,4 @@ from .dice import DiceLoss, dice_loss
|
|||||||
|
|
||||||
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||||
from .spectral_norm import add_spectral_norm, rm_spectral_norm
|
from .spectral_norm import add_spectral_norm, rm_spectral_norm
|
||||||
|
from .instance_noise import InstanceNoise
|
||||||
|
18
map2map/models/instance_noise.py
Normal file
18
map2map/models/instance_noise.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from math import log
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class InstanceNoise:
    """Instance noise with a heuristic annealing schedule.

    Gaussian noise of a linearly decaying standard deviation is meant to be
    added to the adversary's inputs to stabilize GAN training.  The decay is
    driven by the cumulative adversarial loss: once the accumulated loss
    reaches ``ln(2) * batches`` (the equilibrium BCE loss ``ln 2`` per batch,
    summed over a heuristic budget of ``1e5`` batches), the noise std hits 0
    and stays there.
    """

    def __init__(self, init_std):
        # noise std at the very start of training
        self.init_std = init_std
        # running total of adversarial loss fed to std()
        self.adv_loss_cum = 0
        # equilibrium discriminator BCE loss per batch
        self.ln2 = log(2)
        # heuristic annealing budget, in batches
        self.batches = 1e5

    def std(self, adv_loss):
        """Accumulate `adv_loss` and return the annealed noise std.

        The returned std shrinks linearly from ``init_std`` to 0 as the
        cumulative loss approaches ``ln(2) * batches``, then clamps at 0.
        """
        self.adv_loss_cum += adv_loss
        fraction_left = 1 - self.adv_loss_cum / (self.ln2 * self.batches)
        return max(0, fraction_left) * self.init_std
|
@ -19,7 +19,8 @@ from .data.figures import fig3d
|
|||||||
from . import models
|
from . import models
|
||||||
from .models import (narrow_like,
|
from .models import (narrow_like,
|
||||||
adv_model_wrapper, adv_criterion_wrapper,
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
add_spectral_norm, rm_spectral_norm)
|
add_spectral_norm, rm_spectral_norm,
|
||||||
|
InstanceNoise)
|
||||||
from .state import load_model_state_dict
|
from .state import load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -210,6 +211,9 @@ def gpu_worker(local_rank, node, args):
|
|||||||
pprint(vars(args))
|
pprint(vars(args))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
if args.adv:
|
||||||
|
args.instance_noise = InstanceNoise(args.instance_noise)
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
@ -297,6 +301,16 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv and epoch >= args.adv_start:
|
if args.adv and epoch >= args.adv_start:
|
||||||
|
try:
|
||||||
|
noise_std = args.instance_noise.std(adv_loss)
|
||||||
|
except NameError:
|
||||||
|
noise_std = args.instance_noise.std(0)
|
||||||
|
if noise_std > 0:
|
||||||
|
noise = noise_std * torch.randn_like(output)
|
||||||
|
output = output + noise.detach()
|
||||||
|
target = target + noise.detach()
|
||||||
|
del noise
|
||||||
|
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
output = torch.cat([input, output], dim=1)
|
output = torch.cat([input, output], dim=1)
|
||||||
target = torch.cat([input, target], dim=1)
|
target = torch.cat([input, target], dim=1)
|
||||||
@ -367,6 +381,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
'last': grads[1],
|
'last': grads[1],
|
||||||
}, global_step=batch)
|
}, global_step=batch)
|
||||||
|
|
||||||
|
if args.adv and epoch >= args.adv_start:
|
||||||
|
logger.add_scalar('instance_noise', noise_std,
|
||||||
|
global_step=batch)
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user