Swap generator & discriminator order in training
The reasoning is that updating the generator first will free up the memory taken by the model's computation graph before the discriminator is updated.
This commit is contained in:
parent
f99fd6b177
commit
862c9e75a0
@ -230,6 +230,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
adv_model.train()
|
||||
|
||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||
# loss: generator (model) supervised loss
|
||||
# loss_adv: generator (model) adversarial loss
|
||||
# adv_loss: discriminator (adv_model) loss
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
@ -242,6 +245,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
@ -251,40 +255,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
#
|
||||
# outtgt = torch.cat([output.detach(), target], dim=0)
|
||||
#
|
||||
# eval_outtgt = adv_model(outtgt)
|
||||
#
|
||||
# fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
||||
# fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:])
|
||||
# real = torch.ones(1, dtype=torch.float32, device=args.device)
|
||||
# real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:])
|
||||
# fakereal = torch.cat([fake, real], dim=0)
|
||||
|
||||
eval_out = adv_model(output.detach())
|
||||
eval_out = adv_model(output)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out)
|
||||
fake = torch.zeros(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out)
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_tgt)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
# generator adversarial loss
|
||||
|
||||
eval_out = adv_model(output)
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
@ -296,6 +271,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# discriminator
|
||||
if args.adv:
|
||||
eval_out = adv_model(output.detach())
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
batch = epoch * len(loader) + i + 1
|
||||
if batch % args.log_interval == 0:
|
||||
dist.all_reduce(loss)
|
||||
@ -334,7 +326,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
||||
if args.adv:
|
||||
adv_model.eval()
|
||||
|
||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
Loading…
Reference in New Issue
Block a user