Swap generator & discriminator order in training
The reasoning is that updating the generator first frees the memory held by the generator's computation graph before the discriminator's losses are built.
parent f99fd6b177
commit 862c9e75a0
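In the old order (discriminator step, then generator step) both models' graphs can be alive at the same time; running the generator's backward pass first releases its graph before the discriminator's two forward passes allocate theirs. A minimal sketch of the swapped step, under assumed stand-in names (`G`, `D`, `opt_G`, `opt_D` for this repo's `model`, `adv_model`, `optimizer`, `adv_optimizer`), not the repository's actual API:

# Sketch of the reordered GAN step (hypothetical names):
# generator update first, discriminator update second.
import torch

def train_step(G, D, opt_G, opt_D, criterion, adv_criterion, input, target):
    output = G(input)
    eval_out = D(output)
    real = torch.ones_like(eval_out)
    fake = torch.zeros_like(eval_out)

    # generator update: supervised loss + adversarial loss
    loss = criterion(output, target) + adv_criterion(eval_out, real)
    opt_G.zero_grad()
    loss.backward()    # autograd frees the generator's graph here
    opt_G.step()

    # discriminator update: detach() so no graph reaches back into G
    adv_loss = 0.5 * (adv_criterion(D(output.detach()), fake)
                      + adv_criterion(D(target), real))
    opt_D.zero_grad()  # also clears D grads left over from loss.backward()
    adv_loss.backward()
    opt_D.step()
    return loss.item(), adv_loss.item()

Because `loss.backward()` runs before the discriminator losses are constructed, the generator graph's saved activations are already released when `D(output.detach())` and `D(target)` allocate theirs, lowering peak memory.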
@@ -230,6 +230,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         adv_model.train()
 
     # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
+    # loss: generator (model) supervised loss
+    # loss_adv: generator (model) adversarial loss
+    # adv_loss: discriminator (adv_model) loss
     epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
 
     for i, (input, target) in enumerate(loader):
@@ -242,6 +245,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         loss = criterion(output, target)
         epoch_loss[0] += loss.item()
 
+        # generator adversarial loss
         if args.adv:
             if args.cgan:
                 if hasattr(model, 'scale_factor') and model.scale_factor != 1:
@@ -251,40 +255,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
                 output = torch.cat([input, output], dim=1)
                 target = torch.cat([input, target], dim=1)
 
-            # discriminator
-            #
-            # outtgt = torch.cat([output.detach(), target], dim=0)
-            #
-            # eval_outtgt = adv_model(outtgt)
-            #
-            # fake = torch.zeros(1, dtype=torch.float32, device=args.device)
-            # fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:])
-            # real = torch.ones(1, dtype=torch.float32, device=args.device)
-            # real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:])
-            # fakereal = torch.cat([fake, real], dim=0)
-
-            eval_out = adv_model(output.detach())
+            eval_out = adv_model(output)
+            real = torch.ones(1, dtype=torch.float32,
+                              device=args.device).expand_as(eval_out)
             fake = torch.zeros(1, dtype=torch.float32,
                                device=args.device).expand_as(eval_out)
-            adv_loss_fake = adv_criterion(eval_out, fake)  # FIXME try min
-            epoch_loss[3] += adv_loss_fake.item()
-
-            eval_tgt = adv_model(target)
-            real = torch.ones(1, dtype=torch.float32,
-                              device=args.device).expand_as(eval_tgt)
-            adv_loss_real = adv_criterion(eval_tgt, real)  # FIXME try min
-            epoch_loss[4] += adv_loss_real.item()
-
-            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            epoch_loss[2] += adv_loss.item()
-
-            adv_optimizer.zero_grad()
-            adv_loss.backward()
-            adv_optimizer.step()
-
-            # generator adversarial loss
-
-            eval_out = adv_model(output)
             loss_adv = adv_criterion(eval_out, real)  # FIXME try min
             epoch_loss[1] += loss_adv.item()
 
@@ -296,6 +271,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         loss.backward()
         optimizer.step()
 
+        # discriminator
+        if args.adv:
+            eval_out = adv_model(output.detach())
+            adv_loss_fake = adv_criterion(eval_out, fake)  # FIXME try min
+            epoch_loss[3] += adv_loss_fake.item()
+
+            eval_tgt = adv_model(target)
+            adv_loss_real = adv_criterion(eval_tgt, real)  # FIXME try min
+            epoch_loss[4] += adv_loss_real.item()
+
+            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
+            epoch_loss[2] += adv_loss.item()
+
+            adv_optimizer.zero_grad()
+            adv_loss.backward()
+            adv_optimizer.step()
+
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             dist.all_reduce(loss)
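The `output.detach()` in the moved block appears to be load-bearing: `loss.backward()` above has already freed the generator's graph, and detaching keeps `adv_loss.backward()` from trying to traverse that graph a second time (which would typically raise PyTorch's "backward through the graph a second time" error). A standalone toy check of the detach behavior, not repository code:

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).detach()   # y is cut off from the graph that produced it
z = (y ** 2).sum()
z.backward()           # succeeds; backpropagation stops at y
print(x.grad)          # None: no gradient reached x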
@@ -334,7 +326,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
     if args.adv:
         adv_model.eval()
 
-    # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
     epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
 
     with torch.no_grad():