Add delay epoch before updating generator with adversarial loss

Yin Li 2020-01-22 18:44:09 -05:00
parent cdb00ebd8d
commit f99fd6b177
2 changed files with 5 additions and 2 deletions


@@ -63,6 +63,8 @@ def add_train_args(parser):
                         help='adversary criterion from torch.nn')
     parser.add_argument('--cgan', action='store_true',
                         help='enable conditional GAN')
+    parser.add_argument('--adv-delay', default=0, type=int,
+                        help='epoch before updating the generator with adversarial loss')
     parser.add_argument('--optimizer', default='Adam', type=str,
                         help='optimizer from torch.optim')
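
Usage sketch for the new flag (the script name train.py is an assumption, and "..." stands for the model/data arguments this diff does not show):

    python train.py --cgan --adv-delay 5 --optimizer Adam ...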


@@ -288,8 +288,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             loss_adv = adv_criterion(eval_out, real)  # FIXME try min
             epoch_loss[1] += loss_adv.item()
-            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
-            loss += loss_fac * (loss_adv - loss_adv.item())  # FIXME does this work?
+            if epoch >= args.adv_delay:
+                loss_fac = loss.item() / (loss_adv.item() + 1e-8)
+                loss += loss_fac * (loss_adv - loss_adv.item())  # FIXME does this work?
             optimizer.zero_grad()
             loss.backward()
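
A standalone sketch of the behavior this hunk introduces: the generator is trained on the supervised loss alone until epoch >= adv_delay, after which the adversarial term is blended in with the magnitude-matching factor from the diff. All names here (G, D, x, y, the criteria) are placeholders for illustration, not the repository's actual code.

    import torch

    def train_generator_step(G, D, x, y, criterion, adv_criterion,
                             optimizer, epoch, adv_delay):
        out = G(x)
        loss = criterion(out, y)              # supervised loss (e.g. L1/MSE, assumed)

        eval_out = D(out)
        real = torch.ones_like(eval_out)      # "real" labels for the generator update
        loss_adv = adv_criterion(eval_out, real)

        if epoch >= adv_delay:
            # Scale the adversarial term so its gradient magnitude roughly matches
            # the supervised loss; subtracting loss_adv.item() keeps the reported
            # loss value unchanged while gradients still flow through loss_adv.
            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
            loss = loss + loss_fac * (loss_adv - loss_adv.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item(), loss_adv.item()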