Add delay epoch before updating generator with adversarial loss
This commit is contained in:
parent
cdb00ebd8d
commit
f99fd6b177
@ -63,6 +63,8 @@ def add_train_args(parser):
|
||||
help='adversary criterion from torch.nn')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-delay', default=0, type=int,
|
||||
help='epoch before updating the generator with adversarial loss')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
|
@ -288,8 +288,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
|
||||
loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
|
||||
if epoch >= args.adv_delay:
|
||||
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
|
||||
loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
Loading…
Reference in New Issue
Block a user