Add spectral normalization

This commit is contained in:
Yin Li 2020-02-11 11:37:11 -05:00
parent b41e85eda5
commit 2e2adc761d
5 changed files with 39 additions and 10 deletions

View file

@ -68,6 +68,8 @@ def add_train_args(parser):
parser.add_argument('--adv-model', type=str,
help='enable adversary model from .models')
parser.add_argument('--adv-model-spectral-norm', action='store_true',
help='enable spectral normalization on the adversary model')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
help='adversarial criterion from torch.nn')
parser.add_argument('--min-reduction', action='store_true',

View file

@ -6,3 +6,6 @@ from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like
from .dice import DiceLoss, dice_loss
from .adversary import adv_model_wrapper, adv_criterion_wrapper
from .spectral_norm import add_spectral_norm, rm_spectral_norm

View file

@ -1,21 +1,21 @@
import torch
def adv_model_wrapper(module):
    """Wrap an adversary model to also take lists of Tensors as input,
    to be concatenated along the batch dimension

    ``module`` is the model *class* (not an instance). The returned subclass
    is constructed exactly like ``module``; only ``forward`` differs: any
    input that is not a single ``torch.Tensor`` is passed to ``torch.cat``
    along dim 0 before being handed to the parent ``forward``.
    """
    class new_module(module):
        def forward(self, x):
            # A non-Tensor input is treated as a sequence of Tensors to be
            # stacked along the batch dimension.
            if not isinstance(x, torch.Tensor):
                x = torch.cat(x, dim=0)

            return super().forward(x)

    return new_module
def adv_criterion_wrapper(cls):
def adv_criterion_wrapper(module):
"""Wrap an adversarial criterion to:
* also take lists of Tensors as target, used to split the input Tensor
along the batch dimension
@ -23,7 +23,7 @@ def adv_criterion_wrapper(cls):
* expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors
"""
class newcls(cls):
class new_module(module):
def forward(self, input, target):
assert isinstance(input, torch.Tensor)
@ -41,10 +41,12 @@ def adv_criterion_wrapper(cls):
if self.reduction == 'min':
self.reduction = 'mean' # average over batches
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
self.reduction = 'min'
else:
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
return loss
@ -58,4 +60,4 @@ def adv_criterion_wrapper(cls):
return torch.split(input, size, dim=0)
return newcls
return new_module

View file

@ -0,0 +1,19 @@
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm
def add_spectral_norm(module):
    """Apply spectral normalization to the linear and convolutional
    descendants of ``module``.

    Each child that is an ``nn.Linear``, ``nn.Conv{1,2,3}d`` or
    ``nn.ConvTranspose{1,2,3}d`` is passed through
    ``torch.nn.utils.spectral_norm`` and re-assigned on its parent under the
    same name; any other child is recursed into. ``module`` itself is never
    wrapped, only its descendants. Returns ``None``.
    """
    normalizable = (
        nn.Linear,
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )

    for attr, layer in module.named_children():
        if not isinstance(layer, normalizable):
            # Container or unsupported layer: look further down the tree.
            add_spectral_norm(layer)
            continue

        setattr(module, attr, spectral_norm(layer))
def rm_spectral_norm(module):
    """Remove spectral normalization from the linear and convolutional
    descendants of ``module``.

    Each child that is an ``nn.Linear``, ``nn.Conv{1,2,3}d`` or
    ``nn.ConvTranspose{1,2,3}d`` is passed through
    ``torch.nn.utils.remove_spectral_norm`` and re-assigned on its parent
    under the same name; any other child is recursed into. ``module`` itself
    is never touched, only its descendants. Returns ``None``.

    Raises:
        ValueError: propagated from ``remove_spectral_norm`` when a matching
            layer does not currently have spectral normalization applied.
    """
    # ``nn._ConvNd`` is not exported by ``torch.nn`` (accessing it raises
    # AttributeError), so list the concrete layer types explicitly -- the
    # same set that ``add_spectral_norm`` normalizes.
    normalized_types = (
        nn.Linear,
        nn.Conv1d, nn.Conv2d, nn.Conv3d,
        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
    )

    for name, child in module.named_children():
        if isinstance(child, normalized_types):
            setattr(module, name, remove_spectral_norm(child))
        else:
            rm_spectral_norm(child)

View file

@ -13,8 +13,9 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset
from .data.figures import fig3d
from . import models
from .models import narrow_like
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
from .models import (narrow_like,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm)
from .state import load_model_state_dict
@ -148,6 +149,8 @@ def gpu_worker(local_rank, node, args):
adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(args.in_chan + args.out_chan)
if args.cgan else sum(args.out_chan), 1)
if args.adv_model_spectral_norm:
add_spectral_norm(adv_model)
adv_model.to(device)
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group())