From 0a2fc9a9e954a19e909758475afb94367dafdde8 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 3 Feb 2020 11:21:49 -0600 Subject: [PATCH] Add wrappers of adversary model and adversarial loss commit c0dafec94bb7d131938650027f84e5308bf16ffd Author: Yin Li Date: Mon Feb 3 11:18:08 2020 -0600 Fix bug commit b470b873649515f4b8a1cac7b4b33181eac51199 Author: Yin Li Date: Mon Feb 3 09:39:08 2020 -0600 Fix bug commit 9f8f64b3510c72bfcf2a1236ba5285edf280701c Author: Yin Li Date: Mon Feb 3 10:20:37 2020 -0500 Add wrappers of adversary model and adversarial loss --- map2map/args.py | 4 ++- map2map/models/adversary.py | 61 +++++++++++++++++++++++++++++++++++++ map2map/models/conv.py | 4 +-- map2map/train.py | 40 +++++++++--------------- 4 files changed, 81 insertions(+), 28 deletions(-) create mode 100644 map2map/models/adversary.py diff --git a/map2map/args.py b/map2map/args.py index 768dd35..36d073e 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -66,7 +66,9 @@ def add_train_args(parser): parser.add_argument('--adv-model', type=str, help='enable adversary model from .models') parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str, - help='adversary criterion from torch.nn') + help='adversarial criterion from torch.nn') + parser.add_argument('--min-reduction', action='store_true', + help='enable minimum reduction in adversarial criterion') parser.add_argument('--cgan', action='store_true', help='enable conditional GAN') parser.add_argument('--adv-delay', default=0, type=int, diff --git a/map2map/models/adversary.py b/map2map/models/adversary.py new file mode 100644 index 0000000..25979d0 --- /dev/null +++ b/map2map/models/adversary.py @@ -0,0 +1,61 @@ +import torch + + +def adv_model_wrapper(cls): + """Wrap an adversary model to also take lists of Tensors as input, + to be concatenated along the batch dimension + """ + class newcls(cls): + def forward(self, x): + if not isinstance(x, torch.Tensor): + x = torch.cat(x, dim=0) + + return super().forward(x) + + return newcls + + +def adv_criterion_wrapper(cls): + """Wrap an adversarial criterion to: + * also take lists of Tensors as target, used to split the input Tensor + along the batch dimension + * enable min reduction on input + * expand target shape as that of input + * return a list of losses, one for each pair of input and target Tensors + """ + class newcls(cls): + def forward(self, input, target): + assert isinstance(input, torch.Tensor) + + if isinstance(target, torch.Tensor): + input = [input] + target = [target] + else: + input = self.split_input(input, target) + assert len(input) == len(target) + + if self.reduction == 'min': + input = [torch.min(i).unsqueeze(0) for i in input] + + target = [t.expand_as(i) for i, t in zip(input, target)] + + if self.reduction == 'min': + self.reduction = 'mean' # average over batches + loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)] + self.reduction = 'min' + else: + loss = [super(newcls, self).forward(i, t) for i, t in zip(input, target)] + + return loss + + @staticmethod + def split_input(input, target): + assert all(t.dim() == target[0].dim() > 0 for t in target) + if all(t.shape[0] == 1 for t in target): + size = input.shape[0] // len(target) + else: + size = [t.shape[0] for t in target] + + return torch.split(input, size, dim=0) + + return newcls diff --git a/map2map/models/conv.py b/map2map/models/conv.py index aa3a548..b30976c 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -39,9 +39,9 @@ class ConvBlock(nn.Module): in_chan, out_chan = self._setup_conv() return nn.Conv3d(in_chan, out_chan, self.kernel_size) elif l == 'B': - #return nn.BatchNorm3d(self.norm_chan) + return nn.BatchNorm3d(self.norm_chan) #return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True) - return nn.InstanceNorm3d(self.norm_chan) + #return nn.InstanceNorm3d(self.norm_chan) elif l == 'A': return nn.LeakyReLU() else: diff --git a/map2map/train.py b/map2map/train.py index 75d5378..7705ed4 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from .data import FieldDataset from . import models from .models import narrow_like +from .models.adversary import adv_model_wrapper, adv_criterion_wrapper def node_worker(args): @@ -108,6 +109,7 @@ def gpu_worker(local_rank, args): args.adv = args.adv_model is not None if args.adv: adv_model = getattr(models, args.adv_model) + adv_model = adv_model_wrapper(adv_model) adv_model = adv_model(sum(in_chan + out_chan) if args.cgan else sum(out_chan), 1) adv_model.to(args.device) @@ -115,7 +117,8 @@ def gpu_worker(local_rank, args): process_group=dist.new_group()) adv_criterion = getattr(torch.nn, args.adv_criterion) - adv_criterion = adv_criterion() + adv_criterion = adv_criterion_wrapper(adv_criterion) + adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean') adv_criterion.to(args.device) if args.adv_lr is None: @@ -234,6 +237,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, # loss_adv: generator (model) adversarial loss # adv_loss: discriminator (adv_model) loss epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) + real = torch.ones(1, dtype=torch.float32, device=args.device) + fake = torch.zeros(1, dtype=torch.float32, device=args.device) for i, (input, target) in enumerate(loader): input = input.to(args.device, non_blocking=True) @@ -258,11 +263,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, target = torch.cat([input, target], dim=1) eval_out = adv_model(output) - real = torch.ones(1, dtype=torch.float32, - device=args.device).expand_as(eval_out) - fake = torch.zeros(1, dtype=torch.float32, - device=args.device).expand_as(eval_out) - loss_adv = adv_criterion(eval_out, real) # FIXME try min + loss_adv, = adv_criterion(eval_out, real) epoch_loss[1] += loss_adv.item() if epoch >= args.adv_delay: @@ -275,14 +276,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, # discriminator if args.adv: - eval_out = adv_model(output.detach()) - adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min + eval = adv_model([output.detach(), target]) + adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real]) epoch_loss[3] += adv_loss_fake.item() - - eval_tgt = adv_model(target) - adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min epoch_loss[4] += adv_loss_real.item() - adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) epoch_loss[2] += adv_loss.item() @@ -329,6 +326,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): adv_model.eval() epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) + fake = torch.zeros(1, dtype=torch.float32, device=args.device) + real = torch.ones(1, dtype=torch.float32, device=args.device) with torch.no_grad(): for input, target in loader: @@ -353,25 +352,16 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): target = torch.cat([input, target], dim=1) # discriminator - - eval_out = adv_model(output) - fake = torch.zeros(1, dtype=torch.float32, - device=args.device).expand_as(eval_out) # FIXME criterion wrapper: both D&G; min reduction; expand_as - adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min + eval = adv_model([output, target]) + adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real]) epoch_loss[3] += adv_loss_fake.item() - - eval_tgt = adv_model(target) - real = torch.ones(1, dtype=torch.float32, - device=args.device).expand_as(eval_tgt) - adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min epoch_loss[4] += adv_loss_real.item() - adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) epoch_loss[2] += adv_loss.item() # generator adversarial loss - - loss_adv = adv_criterion(eval_out, real) # FIXME try min + eval_out, _ = adv_criterion.split_input(eval, [fake, real]) + loss_adv, = adv_criterion(eval_out, real) epoch_loss[1] += loss_adv.item() dist.all_reduce(epoch_loss)