Disable D grad when updating G; Revert D and G order
The latter is to follow what most GAN implementations do, though I haven't found the reason behind that convention.
parent aeeeb966d8
commit db3414e11c
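For context on the reordering: many PyTorch GAN training loops (pix2pix-style code, for instance) update the discriminator first, then freeze its parameters while the generator step runs, so the generator's backward pass does not write gradients into the discriminator's buffers. The sketch below is only an illustration of that convention, not this repository's code; all names (gen, disc, gen_opt, disc_opt, bce, cond, real_x) are made up.

```python
import torch
import torch.nn as nn

def gan_step(gen, disc, gen_opt, disc_opt, bce, cond, real_x):
    """Illustrative per-batch D-then-G update; not this repo's train()."""
    fake_x = gen(cond)

    # 1) Discriminator update: detach fake_x so no gradient reaches the generator.
    for p in disc.parameters():
        p.requires_grad = True
    d_fake = disc(fake_x.detach())
    d_real = disc(real_x)
    d_loss = 0.5 * (bce(d_fake, torch.zeros_like(d_fake))
                    + bce(d_real, torch.ones_like(d_real)))
    disc_opt.zero_grad()
    d_loss.backward()
    disc_opt.step()

    # 2) Generator update: freeze D so its .grad buffers stay untouched,
    #    even though gradients still flow *through* D to the generator.
    for p in disc.parameters():
        p.requires_grad = False
    g_loss = bce(disc(fake_x), torch.ones_like(d_fake))
    gen_opt.zero_grad()
    g_loss.backward()
    gen_opt.step()
    return d_loss.item(), g_loss.item()
```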
@@ -296,10 +296,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             input = input[:, :-args.noise_chan]  # remove noise channels
 
         target = narrow_like(target, output)  # FIXME pad
-        loss = criterion(output, target)
-        epoch_loss[0] += loss.item()
 
-        # generator adversarial loss
+        # discriminator
         if args.adv:
             if args.cgan:
                 if hasattr(model, 'scale_factor') and model.scale_factor != 1:
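The cGAN branch above conditions the discriminator by concatenating the input field with the generated or true output along the channel dimension; the `hasattr(model, 'scale_factor')` guard presumably brings the input to the output resolution first when the model changes resolution. A minimal, self-contained illustration of that tensor manipulation (shapes and names are invented, not taken from this repo):

```python
import torch

# Hypothetical 3D fields, shape N x C x D x H x W.
cond = torch.randn(2, 1, 32, 32, 32)    # conditioning input channels
fake = torch.randn(2, 1, 32, 32, 32)    # generator output
real = torch.randn(2, 1, 32, 32, 32)    # target

# Concatenate along dim=1 (channels) so the discriminator sees (condition, sample) pairs.
d_in_fake = torch.cat([cond, fake], dim=1)
d_in_real = torch.cat([cond, real], dim=1)

assert d_in_fake.shape == (2, 2, 32, 32, 32)
assert d_in_real.shape == (2, 2, 32, 32, 32)
```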
@@ -309,6 +307,26 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
                 output = torch.cat([input, output], dim=1)
                 target = torch.cat([input, target], dim=1)
 
+            set_requires_grad(adv_model, True)
+
+            eval = adv_model([output.detach(), target])
+            adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
+            epoch_loss[3] += adv_loss_fake.item()
+            epoch_loss[4] += adv_loss_real.item()
+            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
+            epoch_loss[2] += adv_loss.item()
+
+            adv_optimizer.zero_grad()
+            adv_loss.backward()
+            adv_optimizer.step()
+
+        loss = criterion(output, target)
+        epoch_loss[0] += loss.item()
+
+        # generator adversarial loss
+        if args.adv:
+            set_requires_grad(adv_model, False)
+
             eval_out = adv_model(output)
             loss_adv, = adv_criterion(eval_out, real)
             epoch_loss[1] += loss_adv.item()
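Two different mechanisms keep the two updates from interfering here, and they are easy to conflate: `output.detach()` stops the discriminator loss from backpropagating into the generator, while `set_requires_grad(adv_model, False)` keeps the generator's adversarial loss from accumulating gradients in the discriminator's `.grad` buffers (gradients still flow *through* the frozen discriminator to reach the generator). A toy check of both facts, with throwaway stand-in modules rather than this repo's models:

```python
import torch
import torch.nn as nn

g = nn.Linear(4, 4)   # stand-in "generator"
d = nn.Linear(4, 1)   # stand-in "discriminator"

x = torch.randn(8, 4)
out = g(x)

# Discriminator-style step: detach() blocks gradients into g.
d_loss = d(out.detach()).mean()
d_loss.backward()
assert all(p.grad is None for p in g.parameters())        # g untouched
assert all(p.grad is not None for p in d.parameters())    # d got gradients

# Generator-style step: freeze d; gradients still pass through it to g.
for p in d.parameters():
    p.grad = None
    p.requires_grad = False
g_loss = d(out).mean()
g_loss.backward()
assert all(p.grad is not None for p in g.parameters())    # g got gradients
assert all(p.grad is None for p in d.parameters())        # d's buffers untouched
```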
@@ -323,19 +341,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         loss.backward()
         optimizer.step()
 
-        # discriminator
-        if args.adv:
-            eval = adv_model([output.detach(), target])
-            adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real])
-            epoch_loss[3] += adv_loss_fake.item()
-            epoch_loss[4] += adv_loss_real.item()
-            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            epoch_loss[2] += adv_loss.item()
-
-            adv_optimizer.zero_grad()
-            adv_loss.backward()
-            adv_optimizer.step()
-
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             dist.all_reduce(loss)
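The logging path in the surrounding context reduces the loss across ranks with `torch.distributed.all_reduce` before recording it. As a reminder of the semantics, here is a standalone sketch (not this repo's code) that assumes a process group has already been initialized by the launcher; the default op sums in place, so a mean needs an explicit division:

```python
import torch
import torch.distributed as dist

def reduced_mean_loss(loss: torch.Tensor) -> torch.Tensor:
    """Sum a scalar loss over all ranks and divide by the world size.

    Assumes dist.init_process_group(...) was called elsewhere and that
    `loss` lives on this rank's device.
    """
    loss = loss.detach().clone()                 # keep the backward tensor intact
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)  # in-place sum across ranks
    loss /= dist.get_world_size()
    return loss
```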
@@ -448,3 +453,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
                 global_step=epoch+1)
 
     return epoch_loss
+
+
+def set_requires_grad(module, requires_grad=False):
+    for param in module.parameters():
+        param.requires_grad = requires_grad
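The new module-level helper simply toggles `requires_grad` on every parameter of a module. A quick usage sketch with a throwaway module standing in for `adv_model`; the built-in `nn.Module.requires_grad_()` in recent PyTorch versions does the same thing in one call:

```python
import torch.nn as nn

def set_requires_grad(module, requires_grad=False):
    for param in module.parameters():
        param.requires_grad = requires_grad

disc = nn.Linear(8, 1)            # stand-in for adv_model
set_requires_grad(disc, False)    # freeze before the generator step
assert not any(p.requires_grad for p in disc.parameters())
set_requires_grad(disc, True)     # unfreeze before the discriminator step
assert all(p.requires_grad for p in disc.parameters())

# Equivalent built-in: disc.requires_grad_(False)
```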