Add spectral normalization

This commit is contained in:
Yin Li 2020-02-11 11:37:11 -05:00
parent b41e85eda5
commit 2e2adc761d
5 changed files with 39 additions and 10 deletions

View File

@ -68,6 +68,8 @@ def add_train_args(parser):
parser.add_argument('--adv-model', type=str, parser.add_argument('--adv-model', type=str,
help='enable adversary model from .models') help='enable adversary model from .models')
parser.add_argument('--adv-model-spectral-norm', action='store_true',
help='enable spectral normalization on the adversary model')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str, parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
help='adversarial criterion from torch.nn') help='adversarial criterion from torch.nn')
parser.add_argument('--min-reduction', action='store_true', parser.add_argument('--min-reduction', action='store_true',

View File

@ -6,3 +6,6 @@ from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like from .conv import narrow_like
from .dice import DiceLoss, dice_loss from .dice import DiceLoss, dice_loss
from .adversary import adv_model_wrapper, adv_criterion_wrapper
from .spectral_norm import add_spectral_norm, rm_spectral_norm

View File

@ -1,21 +1,21 @@
import torch import torch
def adv_model_wrapper(cls): def adv_model_wrapper(module):
"""Wrap an adversary model to also take lists of Tensors as input, """Wrap an adversary model to also take lists of Tensors as input,
to be concatenated along the batch dimension to be concatenated along the batch dimension
""" """
class newcls(cls): class new_module(module):
def forward(self, x): def forward(self, x):
if not isinstance(x, torch.Tensor): if not isinstance(x, torch.Tensor):
x = torch.cat(x, dim=0) x = torch.cat(x, dim=0)
return super().forward(x) return super().forward(x)
return newcls return new_module
def adv_criterion_wrapper(cls): def adv_criterion_wrapper(module):
"""Wrap an adversarial criterion to: """Wrap an adversarial criterion to:
* also take lists of Tensors as target, used to split the input Tensor * also take lists of Tensors as target, used to split the input Tensor
along the batch dimension along the batch dimension
@ -23,7 +23,7 @@ def adv_criterion_wrapper(cls):
* expand target shape as that of input * expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors * return a list of losses, one for each pair of input and target Tensors
""" """
class newcls(cls): class new_module(module):
def forward(self, input, target): def forward(self, input, target):
assert isinstance(input, torch.Tensor) assert isinstance(input, torch.Tensor)
@ -41,10 +41,12 @@ def adv_criterion_wrapper(cls):
if self.reduction == 'min': if self.reduction == 'min':
self.reduction = 'mean' # average over batches self.reduction = 'mean' # average over batches
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)] loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
self.reduction = 'min' self.reduction = 'min'
else: else:
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)] loss = [super(new_module, self).forward(i, t)
for i, t in zip(input, target)]
return loss return loss
@ -58,4 +60,4 @@ def adv_criterion_wrapper(cls):
return torch.split(input, size, dim=0) return torch.split(input, size, dim=0)
return newcls return new_module

View File

@ -0,0 +1,19 @@
import torch.nn as nn
from torch.nn.utils import spectral_norm, remove_spectral_norm
def add_spectral_norm(module):
for name, child in module.named_children():
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
setattr(module, name, spectral_norm(child))
else:
add_spectral_norm(child)
def rm_spectral_norm(module):
for name, child in module.named_children():
if isinstance(child, (nn._ConvNd, nn.Linear)):
setattr(module, name, remove_spectral_norm(child))
else:
rm_spectral_norm(child)

View File

@ -13,8 +13,9 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset from .data import FieldDataset
from .data.figures import fig3d from .data.figures import fig3d
from . import models from . import models
from .models import narrow_like from .models import (narrow_like,
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm)
from .state import load_model_state_dict from .state import load_model_state_dict
@ -148,6 +149,8 @@ def gpu_worker(local_rank, node, args):
adv_model = adv_model_wrapper(adv_model) adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(args.in_chan + args.out_chan) adv_model = adv_model(sum(args.in_chan + args.out_chan)
if args.cgan else sum(args.out_chan), 1) if args.cgan else sum(args.out_chan), 1)
if args.adv_model_spectral_norm:
add_spectral_norm(adv_model)
adv_model.to(device) adv_model.to(device)
adv_model = DistributedDataParallel(adv_model, device_ids=[device], adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group()) process_group=dist.new_group())