diff --git a/map2map/train.py b/map2map/train.py index ba38359..d476a96 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -230,6 +230,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, adv_model.train() # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real + # loss: generator (model) supervised loss + # loss_adv: generator (model) adversarial loss + # adv_loss: discriminator (adv_model) loss epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) for i, (input, target) in enumerate(loader): @@ -242,6 +245,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, loss = criterion(output, target) epoch_loss[0] += loss.item() + # generator adversarial loss if args.adv: if args.cgan: if hasattr(model, 'scale_factor') and model.scale_factor != 1: @@ -251,40 +255,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, output = torch.cat([input, output], dim=1) target = torch.cat([input, target], dim=1) - # discriminator -# -# outtgt = torch.cat([output.detach(), target], dim=0) -# -# eval_outtgt = adv_model(outtgt) -# -# fake = torch.zeros(1, dtype=torch.float32, device=args.device) -# fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:]) -# real = torch.ones(1, dtype=torch.float32, device=args.device) -# real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:]) -# fakereal = torch.cat([fake, real], dim=0) - - eval_out = adv_model(output.detach()) + eval_out = adv_model(output) + real = torch.ones(1, dtype=torch.float32, + device=args.device).expand_as(eval_out) fake = torch.zeros(1, dtype=torch.float32, device=args.device).expand_as(eval_out) - adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min - epoch_loss[3] += adv_loss_fake.item() - - eval_tgt = adv_model(target) - real = torch.ones(1, dtype=torch.float32, - device=args.device).expand_as(eval_tgt) - adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min - epoch_loss[4] += adv_loss_real.item() - - adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) - epoch_loss[2] += adv_loss.item() - - adv_optimizer.zero_grad() - adv_loss.backward() - adv_optimizer.step() - - # generator adversarial loss - - eval_out = adv_model(output) loss_adv = adv_criterion(eval_out, real) # FIXME try min epoch_loss[1] += loss_adv.item() @@ -296,6 +271,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, loss.backward() optimizer.step() + # discriminator + if args.adv: + eval_out = adv_model(output.detach()) + adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min + epoch_loss[3] += adv_loss_fake.item() + + eval_tgt = adv_model(target) + adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min + epoch_loss[4] += adv_loss_real.item() + + adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) + epoch_loss[2] += adv_loss.item() + + adv_optimizer.zero_grad() + adv_loss.backward() + adv_optimizer.step() + batch = epoch * len(loader) + i + 1 if batch % args.log_interval == 0: dist.all_reduce(loss) @@ -334,7 +326,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): if args.adv: adv_model.eval() - # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) with torch.no_grad():