Add delay epoch before updating generator with adversarial loss
This commit is contained in:
parent
cdb00ebd8d
commit
f99fd6b177
@ -63,6 +63,8 @@ def add_train_args(parser):
|
|||||||
help='adversary criterion from torch.nn')
|
help='adversary criterion from torch.nn')
|
||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
|
parser.add_argument('--adv-delay', default=0, type=int,
|
||||||
|
help='epoch before updating the generator with adversarial loss')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
@ -288,8 +288,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||||
epoch_loss[1] += loss_adv.item()
|
epoch_loss[1] += loss_adv.item()
|
||||||
|
|
||||||
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
|
if epoch >= args.adv_delay:
|
||||||
loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
|
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
|
||||||
|
loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
Loading…
Reference in New Issue
Block a user