Add spectral normalization
This commit is contained in:
parent
b41e85eda5
commit
2e2adc761d
@ -68,6 +68,8 @@ def add_train_args(parser):
|
|||||||
|
|
||||||
parser.add_argument('--adv-model', type=str,
|
parser.add_argument('--adv-model', type=str,
|
||||||
help='enable adversary model from .models')
|
help='enable adversary model from .models')
|
||||||
|
parser.add_argument('--adv-model-spectral-norm', action='store_true',
|
||||||
|
help='enable spectral normalization on the adversary model')
|
||||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||||
help='adversarial criterion from torch.nn')
|
help='adversarial criterion from torch.nn')
|
||||||
parser.add_argument('--min-reduction', action='store_true',
|
parser.add_argument('--min-reduction', action='store_true',
|
||||||
|
@ -6,3 +6,6 @@ from .patchgan import PatchGAN, PatchGAN42
|
|||||||
from .conv import narrow_like
|
from .conv import narrow_like
|
||||||
|
|
||||||
from .dice import DiceLoss, dice_loss
|
from .dice import DiceLoss, dice_loss
|
||||||
|
|
||||||
|
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||||
|
from .spectral_norm import add_spectral_norm, rm_spectral_norm
|
||||||
|
@ -1,21 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def adv_model_wrapper(cls):
|
def adv_model_wrapper(module):
|
||||||
"""Wrap an adversary model to also take lists of Tensors as input,
|
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||||
to be concatenated along the batch dimension
|
to be concatenated along the batch dimension
|
||||||
"""
|
"""
|
||||||
class newcls(cls):
|
class new_module(module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if not isinstance(x, torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
x = torch.cat(x, dim=0)
|
x = torch.cat(x, dim=0)
|
||||||
|
|
||||||
return super().forward(x)
|
return super().forward(x)
|
||||||
|
|
||||||
return newcls
|
return new_module
|
||||||
|
|
||||||
|
|
||||||
def adv_criterion_wrapper(cls):
|
def adv_criterion_wrapper(module):
|
||||||
"""Wrap an adversarial criterion to:
|
"""Wrap an adversarial criterion to:
|
||||||
* also take lists of Tensors as target, used to split the input Tensor
|
* also take lists of Tensors as target, used to split the input Tensor
|
||||||
along the batch dimension
|
along the batch dimension
|
||||||
@ -23,7 +23,7 @@ def adv_criterion_wrapper(cls):
|
|||||||
* expand target shape as that of input
|
* expand target shape as that of input
|
||||||
* return a list of losses, one for each pair of input and target Tensors
|
* return a list of losses, one for each pair of input and target Tensors
|
||||||
"""
|
"""
|
||||||
class newcls(cls):
|
class new_module(module):
|
||||||
def forward(self, input, target):
|
def forward(self, input, target):
|
||||||
assert isinstance(input, torch.Tensor)
|
assert isinstance(input, torch.Tensor)
|
||||||
|
|
||||||
@ -41,10 +41,12 @@ def adv_criterion_wrapper(cls):
|
|||||||
|
|
||||||
if self.reduction == 'min':
|
if self.reduction == 'min':
|
||||||
self.reduction = 'mean' # average over batches
|
self.reduction = 'mean' # average over batches
|
||||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
loss = [super(new_module, self).forward(i, t)
|
||||||
|
for i, t in zip(input, target)]
|
||||||
self.reduction = 'min'
|
self.reduction = 'min'
|
||||||
else:
|
else:
|
||||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
loss = [super(new_module, self).forward(i, t)
|
||||||
|
for i, t in zip(input, target)]
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@ -58,4 +60,4 @@ def adv_criterion_wrapper(cls):
|
|||||||
|
|
||||||
return torch.split(input, size, dim=0)
|
return torch.split(input, size, dim=0)
|
||||||
|
|
||||||
return newcls
|
return new_module
|
||||||
|
19
map2map/models/spectral_norm.py
Normal file
19
map2map/models/spectral_norm.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.utils import spectral_norm, remove_spectral_norm
|
||||||
|
|
||||||
|
|
||||||
|
def add_spectral_norm(module):
|
||||||
|
for name, child in module.named_children():
|
||||||
|
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||||
|
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||||
|
setattr(module, name, spectral_norm(child))
|
||||||
|
else:
|
||||||
|
add_spectral_norm(child)
|
||||||
|
|
||||||
|
|
||||||
|
def rm_spectral_norm(module):
|
||||||
|
for name, child in module.named_children():
|
||||||
|
if isinstance(child, (nn._ConvNd, nn.Linear)):
|
||||||
|
setattr(module, name, remove_spectral_norm(child))
|
||||||
|
else:
|
||||||
|
rm_spectral_norm(child)
|
@ -13,8 +13,9 @@ from torch.utils.tensorboard import SummaryWriter
|
|||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from .data.figures import fig3d
|
from .data.figures import fig3d
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_like
|
from .models import (narrow_like,
|
||||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
adv_model_wrapper, adv_criterion_wrapper,
|
||||||
|
add_spectral_norm, rm_spectral_norm)
|
||||||
from .state import load_model_state_dict
|
from .state import load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
@ -148,6 +149,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
adv_model = adv_model_wrapper(adv_model)
|
adv_model = adv_model_wrapper(adv_model)
|
||||||
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
||||||
if args.cgan else sum(args.out_chan), 1)
|
if args.cgan else sum(args.out_chan), 1)
|
||||||
|
if args.adv_model_spectral_norm:
|
||||||
|
add_spectral_norm(adv_model)
|
||||||
adv_model.to(device)
|
adv_model.to(device)
|
||||||
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
Loading…
Reference in New Issue
Block a user