Disable D grad when updating G; Revert D and G order
The reason for the latter is to follow most code but I haven't found the reason
This commit is contained in:
parent
aeeeb966d8
commit
db3414e11c
1 changed files with 26 additions and 16 deletions
|
@ -296,10 +296,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
# discriminator
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
|
@ -309,6 +307,26 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
set_requires_grad(adv_model, True)
|
||||
|
||||
eval = adv_model([output.detach(), target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
if args.adv:
|
||||
set_requires_grad(adv_model, False)
|
||||
|
||||
eval_out = adv_model(output)
|
||||
loss_adv, = adv_criterion(eval_out, real)
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
@ -323,19 +341,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# discriminator
|
||||
if args.adv:
|
||||
eval = adv_model([output.detach(), target])
|
||||
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
batch = epoch * len(loader) + i + 1
|
||||
if batch % args.log_interval == 0:
|
||||
dist.all_reduce(loss)
|
||||
|
@ -448,3 +453,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||
global_step =epoch+1)
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def set_requires_grad(module, requires_grad=False):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
|
Loading…
Reference in a new issue