Add wrappers of adversary model and adversarial loss
commit c0dafec94bb7d131938650027f84e5308bf16ffd Author: Yin Li <eelregit@gmail.com> Date: Mon Feb 3 11:18:08 2020 -0600 Fix bug commit b470b873649515f4b8a1cac7b4b33181eac51199 Author: Yin Li <eelregit@gmail.com> Date: Mon Feb 3 09:39:08 2020 -0600 Fix bug commit 9f8f64b3510c72bfcf2a1236ba5285edf280701c Author: Yin Li <eelregit@gmail.com> Date: Mon Feb 3 10:20:37 2020 -0500 Add wrappers of adversary model and adversarial loss
This commit is contained in:
parent
a237245514
commit
0a2fc9a9e9
4 changed files with 81 additions and 28 deletions
|
@ -66,7 +66,9 @@ def add_train_args(parser):
|
|||
parser.add_argument('--adv-model', type=str,
|
||||
help='enable adversary model from .models')
|
||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||
help='adversary criterion from torch.nn')
|
||||
help='adversarial criterion from torch.nn')
|
||||
parser.add_argument('--min-reduction', action='store_true',
|
||||
help='enable minimum reduction in adversarial criterion')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-delay', default=0, type=int,
|
||||
|
|
61
map2map/models/adversary.py
Normal file
61
map2map/models/adversary.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
import torch
|
||||
|
||||
|
||||
def adv_model_wrapper(cls):
|
||||
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||
to be concatenated along the batch dimension
|
||||
"""
|
||||
class newcls(cls):
|
||||
def forward(self, x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.cat(x, dim=0)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
return newcls
|
||||
|
||||
|
||||
def adv_criterion_wrapper(cls):
|
||||
"""Wrap an adversarial criterion to:
|
||||
* also take lists of Tensors as target, used to split the input Tensor
|
||||
along the batch dimension
|
||||
* enable min reduction on input
|
||||
* expand target shape as that of input
|
||||
* return a list of losses, one for each pair of input and target Tensors
|
||||
"""
|
||||
class newcls(cls):
|
||||
def forward(self, input, target):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
|
||||
if isinstance(target, torch.Tensor):
|
||||
input = [input]
|
||||
target = [target]
|
||||
else:
|
||||
input = self.split_input(input, target)
|
||||
assert len(input) == len(target)
|
||||
|
||||
if self.reduction == 'min':
|
||||
input = [torch.min(i).unsqueeze(0) for i in input]
|
||||
|
||||
target = [t.expand_as(i) for i, t in zip(input, target)]
|
||||
|
||||
if self.reduction == 'min':
|
||||
self.reduction = 'mean' # average over batches
|
||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
||||
self.reduction = 'min'
|
||||
else:
|
||||
loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)]
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def split_input(input, target):
|
||||
assert all(t.dim() == target[0].dim() > 0 for t in target)
|
||||
if all(t.shape[0] == 1 for t in target):
|
||||
size = input.shape[0] // len(target)
|
||||
else:
|
||||
size = [t.shape[0] for t in target]
|
||||
|
||||
return torch.split(input, size, dim=0)
|
||||
|
||||
return newcls
|
|
@ -39,9 +39,9 @@ class ConvBlock(nn.Module):
|
|||
in_chan, out_chan = self._setup_conv()
|
||||
return nn.Conv3d(in_chan, out_chan, self.kernel_size)
|
||||
elif l == 'B':
|
||||
#return nn.BatchNorm3d(self.norm_chan)
|
||||
return nn.BatchNorm3d(self.norm_chan)
|
||||
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
||||
return nn.InstanceNorm3d(self.norm_chan)
|
||||
#return nn.InstanceNorm3d(self.norm_chan)
|
||||
elif l == 'A':
|
||||
return nn.LeakyReLU()
|
||||
else:
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
from .data import FieldDataset
|
||||
from . import models
|
||||
from .models import narrow_like
|
||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
|
@ -108,6 +109,7 @@ def gpu_worker(local_rank, args):
|
|||
args.adv = args.adv_model is not None
|
||||
if args.adv:
|
||||
adv_model = getattr(models, args.adv_model)
|
||||
adv_model = adv_model_wrapper(adv_model)
|
||||
adv_model = adv_model(sum(in_chan + out_chan)
|
||||
if args.cgan else sum(out_chan), 1)
|
||||
adv_model.to(args.device)
|
||||
|
@ -115,7 +117,8 @@ def gpu_worker(local_rank, args):
|
|||
process_group=dist.new_group())
|
||||
|
||||
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
||||
adv_criterion = adv_criterion()
|
||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
||||
adv_criterion.to(args.device)
|
||||
|
||||
if args.adv_lr is None:
|
||||
|
@ -234,6 +237,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
# loss_adv: generator (model) adversarial loss
|
||||
# adv_loss: discriminator (adv_model) loss
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
real = torch.ones(1, dtype=torch.float32, device=args.device)
|
||||
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
input = input.to(args.device, non_blocking=True)
|
||||
|
@ -258,11 +263,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
eval_out = adv_model(output)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out)
|
||||
fake = torch.zeros(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out)
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
if epoch >= args.adv_delay:
|
||||
|
@ -275,14 +276,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
|
||||
# discriminator
|
||||
if args.adv:
|
||||
eval_out = adv_model(output.detach())
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
eval = adv_model([output.detach(), target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
|
@ -329,6 +326,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||
adv_model.eval()
|
||||
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
||||
real = torch.ones(1, dtype=torch.float32, device=args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for input, target in loader:
|
||||
|
@ -353,25 +352,16 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
|
||||
eval_out = adv_model(output)
|
||||
fake = torch.zeros(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out) # FIXME criterion wrapper: both D&G; min reduction; expand_as
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
eval = adv_model([output, target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_tgt)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
eval_out, _ = adv_criterion.split_input(eval, [fake, real])
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
|
|
Loading…
Reference in a new issue