Add wrappers of adversary model and adversarial loss

commit c0dafec94bb7d131938650027f84e5308bf16ffd
Author: Yin Li <eelregit@gmail.com>
Date:   Mon Feb 3 11:18:08 2020 -0600

    Fix bug

commit b470b873649515f4b8a1cac7b4b33181eac51199
Author: Yin Li <eelregit@gmail.com>
Date:   Mon Feb 3 09:39:08 2020 -0600

    Fix bug

commit 9f8f64b3510c72bfcf2a1236ba5285edf280701c
Author: Yin Li <eelregit@gmail.com>
Date:   Mon Feb 3 10:20:37 2020 -0500

    Add wrappers of adversary model and adversarial loss
This commit is contained in:
Yin Li 2020-02-03 11:21:49 -06:00
parent a237245514
commit 0a2fc9a9e9
4 changed files with 81 additions and 28 deletions

View File

@ -66,7 +66,9 @@ def add_train_args(parser):
parser.add_argument('--adv-model', type=str,
        help='enable adversary model from .models')
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
        help='adversarial criterion from torch.nn')
parser.add_argument('--min-reduction', action='store_true',
        help='enable minimum reduction in adversarial criterion')
parser.add_argument('--cgan', action='store_true',
        help='enable conditional GAN')
parser.add_argument('--adv-delay', default=0, type=int,

View File

@ -0,0 +1,61 @@
import torch
def adv_model_wrapper(cls):
    """Wrap an adversary model so its ``forward`` also accepts a list of
    Tensors, which are concatenated along the batch dimension before being
    passed to the wrapped model.
    """
    class newcls(cls):
        def forward(self, x):
            # A plain Tensor passes through; anything else is treated as a
            # sequence of Tensors and fused into one batch.
            batch = x if isinstance(x, torch.Tensor) else torch.cat(x, dim=0)
            return super().forward(batch)

    return newcls
def adv_criterion_wrapper(cls):
    """Wrap an adversarial criterion to:
    * also take lists of Tensors as target, used to split the input Tensor
      along the batch dimension
    * enable min reduction on input (``reduction='min'``)
    * expand target shape as that of input
    * return a list of losses, one for each pair of input and target Tensors

    Bug fix vs. original: ``self.reduction`` was temporarily set to
    ``'mean'`` during the min-reduction loss computation and only restored
    afterwards; if the base criterion raised, the criterion object was left
    permanently in ``'mean'`` mode. The swap is now guarded by try/finally.
    """
    class newcls(cls):
        def forward(self, input, target):
            assert isinstance(input, torch.Tensor)

            # Normalize both operands to parallel lists of Tensors.
            if isinstance(target, torch.Tensor):
                input = [input]
                target = [target]
            else:
                input = self.split_input(input, target)
            assert len(input) == len(target)

            # Min reduction collapses each input chunk to its single
            # smallest element (shape (1,)).
            if self.reduction == 'min':
                input = [torch.min(i).unsqueeze(0) for i in input]

            target = [t.expand_as(i) for i, t in zip(input, target)]

            if self.reduction == 'min':
                # The base criterion does not understand 'min'; after the
                # collapse above, averaging over the (single-element) batch
                # is equivalent. Restore 'min' even if forward raises.
                self.reduction = 'mean'
                try:
                    loss = [super(newcls, self).forward(i, t)
                            for i, t in zip(input, target)]
                finally:
                    self.reduction = 'min'
            else:
                loss = [super(newcls, self).forward(i, t)
                        for i, t in zip(input, target)]

            return loss

        @staticmethod
        def split_input(input, target):
            """Split ``input`` along the batch dim to match ``target``.

            If every target has batch size 1 (broadcastable labels), split
            ``input`` into equal chunks, one per target; otherwise split by
            each target's actual batch size.
            """
            assert all(t.dim() == target[0].dim() > 0 for t in target)
            if all(t.shape[0] == 1 for t in target):
                size = input.shape[0] // len(target)
            else:
                size = [t.shape[0] for t in target]
            return torch.split(input, size, dim=0)

    return newcls

View File

@ -39,9 +39,9 @@ class ConvBlock(nn.Module):
in_chan, out_chan = self._setup_conv() in_chan, out_chan = self._setup_conv()
return nn.Conv3d(in_chan, out_chan, self.kernel_size) return nn.Conv3d(in_chan, out_chan, self.kernel_size)
elif l == 'B': elif l == 'B':
#return nn.BatchNorm3d(self.norm_chan) return nn.BatchNorm3d(self.norm_chan)
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True) #return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
return nn.InstanceNorm3d(self.norm_chan) #return nn.InstanceNorm3d(self.norm_chan)
elif l == 'A': elif l == 'A':
return nn.LeakyReLU() return nn.LeakyReLU()
else: else:

View File

@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from .data import FieldDataset from .data import FieldDataset
from . import models from . import models
from .models import narrow_like from .models import narrow_like
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
def node_worker(args): def node_worker(args):
@ -108,6 +109,7 @@ def gpu_worker(local_rank, args):
args.adv = args.adv_model is not None args.adv = args.adv_model is not None
if args.adv: if args.adv:
adv_model = getattr(models, args.adv_model) adv_model = getattr(models, args.adv_model)
adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(in_chan + out_chan) adv_model = adv_model(sum(in_chan + out_chan)
if args.cgan else sum(out_chan), 1) if args.cgan else sum(out_chan), 1)
adv_model.to(args.device) adv_model.to(args.device)
@ -115,7 +117,8 @@ def gpu_worker(local_rank, args):
process_group=dist.new_group()) process_group=dist.new_group())
adv_criterion = getattr(torch.nn, args.adv_criterion) adv_criterion = getattr(torch.nn, args.adv_criterion)
adv_criterion = adv_criterion() adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
adv_criterion.to(args.device) adv_criterion.to(args.device)
if args.adv_lr is None: if args.adv_lr is None:
@ -234,6 +237,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
# loss_adv: generator (model) adversarial loss # loss_adv: generator (model) adversarial loss
# adv_loss: discriminator (adv_model) loss # adv_loss: discriminator (adv_model) loss
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
real = torch.ones(1, dtype=torch.float32, device=args.device)
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
for i, (input, target) in enumerate(loader): for i, (input, target) in enumerate(loader):
input = input.to(args.device, non_blocking=True) input = input.to(args.device, non_blocking=True)
@ -258,11 +263,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
eval_out = adv_model(output) eval_out = adv_model(output)
real = torch.ones(1, dtype=torch.float32, loss_adv, = adv_criterion(eval_out, real)
device=args.device).expand_as(eval_out)
fake = torch.zeros(1, dtype=torch.float32,
device=args.device).expand_as(eval_out)
loss_adv = adv_criterion(eval_out, real) # FIXME try min
epoch_loss[1] += loss_adv.item() epoch_loss[1] += loss_adv.item()
if epoch >= args.adv_delay: if epoch >= args.adv_delay:
@ -275,14 +276,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
# discriminator # discriminator
if args.adv: if args.adv:
eval_out = adv_model(output.detach()) eval = adv_model([output.detach(), target])
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
epoch_loss[3] += adv_loss_fake.item() epoch_loss[3] += adv_loss_fake.item()
eval_tgt = adv_model(target)
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
epoch_loss[4] += adv_loss_real.item() epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item() epoch_loss[2] += adv_loss.item()
@ -329,6 +326,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
adv_model.eval() adv_model.eval()
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
real = torch.ones(1, dtype=torch.float32, device=args.device)
with torch.no_grad(): with torch.no_grad():
for input, target in loader: for input, target in loader:
@ -353,25 +352,16 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
# discriminator # discriminator
eval = adv_model([output, target])
eval_out = adv_model(output) adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
fake = torch.zeros(1, dtype=torch.float32,
device=args.device).expand_as(eval_out) # FIXME criterion wrapper: both D&G; min reduction; expand_as
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
epoch_loss[3] += adv_loss_fake.item() epoch_loss[3] += adv_loss_fake.item()
eval_tgt = adv_model(target)
real = torch.ones(1, dtype=torch.float32,
device=args.device).expand_as(eval_tgt)
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
epoch_loss[4] += adv_loss_real.item() epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item() epoch_loss[2] += adv_loss.item()
# generator adversarial loss # generator adversarial loss
eval_out, _ = adv_criterion.split_input(eval, [fake, real])
loss_adv = adv_criterion(eval_out, real) # FIXME try min loss_adv, = adv_criterion(eval_out, real)
epoch_loss[1] += loss_adv.item() epoch_loss[1] += loss_adv.item()
dist.all_reduce(epoch_loss) dist.all_reduce(epoch_loss)