Disable D grad when updating G; Revert D and G order

The latter change (updating D before G) follows the ordering used in most
existing code, although I haven't found the rationale for that ordering.
This commit is contained in:
Yin Li 2020-02-13 11:33:40 -05:00
parent aeeeb966d8
commit db3414e11c

View File

@ -296,10 +296,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
input = input[:, :-args.noise_chan] # remove noise channels input = input[:, :-args.noise_chan] # remove noise channels
target = narrow_like(target, output) # FIXME pad target = narrow_like(target, output) # FIXME pad
loss = criterion(output, target)
epoch_loss[0] += loss.item()
# generator adversarial loss # discriminator
if args.adv: if args.adv:
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
@ -309,6 +307,26 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
output = torch.cat([input, output], dim=1) output = torch.cat([input, output], dim=1)
target = torch.cat([input, target], dim=1) target = torch.cat([input, target], dim=1)
set_requires_grad(adv_model, True)
eval = adv_model([output.detach(), target])
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
epoch_loss[3] += adv_loss_fake.item()
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss.backward()
adv_optimizer.step()
loss = criterion(output, target)
epoch_loss[0] += loss.item()
# generator adversarial loss
if args.adv:
set_requires_grad(adv_model, False)
eval_out = adv_model(output) eval_out = adv_model(output)
loss_adv, = adv_criterion(eval_out, real) loss_adv, = adv_criterion(eval_out, real)
epoch_loss[1] += loss_adv.item() epoch_loss[1] += loss_adv.item()
@ -323,19 +341,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# discriminator
if args.adv:
eval = adv_model([output.detach(), target])
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
epoch_loss[3] += adv_loss_fake.item()
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss.backward()
adv_optimizer.step()
batch = epoch * len(loader) + i + 1 batch = epoch * len(loader) + i + 1
if batch % args.log_interval == 0: if batch % args.log_interval == 0:
dist.all_reduce(loss) dist.all_reduce(loss)
@ -448,3 +453,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
global_step =epoch+1) global_step =epoch+1)
return epoch_loss return epoch_loss
def set_requires_grad(module, requires_grad=False):
    """Set the ``requires_grad`` flag on every parameter of *module*.

    Iterates over ``module.parameters()`` and assigns *requires_grad* to
    each parameter's ``requires_grad`` attribute in place.  Nothing is
    returned.

    Args:
        module: object exposing a ``parameters()`` method that yields the
            parameters to update.
        requires_grad: value to assign to each parameter's flag; defaults
            to ``False``, i.e. gradients are disabled for the module.
    """
    for param in module.parameters():
        param.requires_grad = requires_grad