Add spectral normalization
This commit is contained in:
parent
b41e85eda5
commit
2e2adc761d
5 changed files with 39 additions and 10 deletions
|
@ -68,6 +68,8 @@ def add_train_args(parser):
|
|||
|
||||
parser.add_argument('--adv-model', type=str,
|
||||
help='enable adversary model from .models')
|
||||
parser.add_argument('--adv-model-spectral-norm', action='store_true',
|
||||
help='enable spectral normalization on the adversary model')
|
||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||
help='adversarial criterion from torch.nn')
|
||||
parser.add_argument('--min-reduction', action='store_true',
|
||||
|
|
|
@ -6,3 +6,6 @@ from .patchgan import PatchGAN, PatchGAN42
|
|||
from .conv import narrow_like
|
||||
|
||||
from .dice import DiceLoss, dice_loss
|
||||
|
||||
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||
from .spectral_norm import add_spectral_norm, rm_spectral_norm
|
||||
|
|
|
@ -1,21 +1,21 @@
|
|||
import torch
|
||||
|
||||
|
||||
def adv_model_wrapper(cls):
|
||||
def adv_model_wrapper(module):
|
||||
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||
to be concatenated along the batch dimension
|
||||
"""
|
||||
class newcls(cls):
|
||||
class new_module(module):
|
||||
def forward(self, x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.cat(x, dim=0)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
return newcls
|
||||
return new_module
|
||||
|
||||
|
||||
def adv_criterion_wrapper(cls):
|
||||
def adv_criterion_wrapper(module):
|
||||
"""Wrap an adversarial criterion to:
|
||||
* also take lists of Tensors as target, used to split the input Tensor
|
||||
along the batch dimension
|
||||
|
@ -23,7 +23,7 @@ def adv_criterion_wrapper(cls):
|
|||
* expand target shape as that of input
|
||||
* return a list of losses, one for each pair of input and target Tensors
|
||||
"""
|
||||
class newcls(cls):
|
||||
class new_module(module):
|
||||
def forward(self, input, target):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
|
||||
|
@ -41,10 +41,12 @@ def adv_criterion_wrapper(cls):
|
|||
|
||||
if self.reduction == 'min':
|
||||
self.reduction = 'mean' # average over batches
|
||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
self.reduction = 'min'
|
||||
else:
|
||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
|
||||
return loss
|
||||
|
||||
|
@ -58,4 +60,4 @@ def adv_criterion_wrapper(cls):
|
|||
|
||||
return torch.split(input, size, dim=0)
|
||||
|
||||
return newcls
|
||||
return new_module
|
||||
|
|
19
map2map/models/spectral_norm.py
Normal file
19
map2map/models/spectral_norm.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import torch.nn as nn
|
||||
from torch.nn.utils import spectral_norm, remove_spectral_norm
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
setattr(module, name, spectral_norm(child))
|
||||
else:
|
||||
add_spectral_norm(child)
|
||||
|
||||
|
||||
def rm_spectral_norm(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, (nn._ConvNd, nn.Linear)):
|
||||
setattr(module, name, remove_spectral_norm(child))
|
||||
else:
|
||||
rm_spectral_norm(child)
|
|
@ -13,8 +13,9 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
from .data import FieldDataset
|
||||
from .data.figures import fig3d
|
||||
from . import models
|
||||
from .models import narrow_like
|
||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||
from .models import (narrow_like,
|
||||
adv_model_wrapper, adv_criterion_wrapper,
|
||||
add_spectral_norm, rm_spectral_norm)
|
||||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
|
@ -148,6 +149,8 @@ def gpu_worker(local_rank, node, args):
|
|||
adv_model = adv_model_wrapper(adv_model)
|
||||
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
||||
if args.cgan else sum(args.out_chan), 1)
|
||||
if args.adv_model_spectral_norm:
|
||||
add_spectral_norm(adv_model)
|
||||
adv_model.to(device)
|
||||
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
|
Loading…
Reference in a new issue