Swap generator & discriminator order in training

The reasoning is that updating generator first will free up the memory
taken by the graph of the model
This commit is contained in:
Yin Li 2020-01-22 18:59:55 -05:00
parent f99fd6b177
commit 862c9e75a0

View File

@ -230,6 +230,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
adv_model.train() adv_model.train()
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
# loss: generator (model) supervised loss
# loss_adv: generator (model) adversarial loss
# adv_loss: discriminator (adv_model) loss
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
for i, (input, target) in enumerate(loader): for i, (input, target) in enumerate(loader):
@ -242,6 +245,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
loss = criterion(output, target) loss = criterion(output, target)
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
# generator adversarial loss
if args.adv: if args.adv:
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
@ -251,40 +255,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
output = torch.cat([input, output], dim=1) output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
# discriminator eval_out = adv_model(output)
# real = torch.ones(1, dtype=torch.float32,
# outtgt = torch.cat([output.detach(), target], dim=0) device=args.device).expand_as(eval_out)
#
# eval_outtgt = adv_model(outtgt)
#
# fake = torch.zeros(1, dtype=torch.float32, device=args.device)
# fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:])
# real = torch.ones(1, dtype=torch.float32, device=args.device)
# real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:])
# fakereal = torch.cat([fake, real], dim=0)
eval_out = adv_model(output.detach())
fake = torch.zeros(1, dtype=torch.float32, fake = torch.zeros(1, dtype=torch.float32,
device=args.device).expand_as(eval_out) device=args.device).expand_as(eval_out)
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
epoch_loss[3] += adv_loss_fake.item()
eval_tgt = adv_model(target)
real = torch.ones(1, dtype=torch.float32,
device=args.device).expand_as(eval_tgt)
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss.backward()
adv_optimizer.step()
# generator adversarial loss
eval_out = adv_model(output)
loss_adv = adv_criterion(eval_out, real) # FIXME try min loss_adv = adv_criterion(eval_out, real) # FIXME try min
epoch_loss[1] += loss_adv.item() epoch_loss[1] += loss_adv.item()
@ -296,6 +271,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# discriminator
if args.adv:
eval_out = adv_model(output.detach())
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
epoch_loss[3] += adv_loss_fake.item()
eval_tgt = adv_model(target)
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss.backward()
adv_optimizer.step()
batch = epoch * len(loader) + i + 1 batch = epoch * len(loader) + i + 1
if batch % args.log_interval == 0: if batch % args.log_interval == 0:
dist.all_reduce(loss) dist.all_reduce(loss)
@ -334,7 +326,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
if args.adv: if args.adv:
adv_model.eval() adv_model.eval()
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device) epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
with torch.no_grad(): with torch.no_grad():