Delay scheduler stepping by adv_delay epochs

This commit is contained in:
Yin Li 2020-02-03 14:44:14 -05:00
parent 0a2fc9a9e9
commit 7f6578c63e
2 changed files with 11 additions and 5 deletions

View File

@@ -72,7 +72,8 @@ def add_train_args(parser):
parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN')
parser.add_argument('--adv-delay', default=0, type=int,
help='epoch before updating the generator with adversarial loss')
help='epoch before updating the generator with adversarial loss, '
'and the learning rate with scheduler')
parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim')

View File

@@ -187,16 +187,21 @@ def gpu_worker(local_rank, args):
args)
epoch_loss = val_loss
scheduler.step(epoch_loss[0])
if args.adv:
adv_scheduler.step(epoch_loss[0])
if epoch >= args.adv_delay:
scheduler.step(epoch_loss[0])
if args.adv:
adv_scheduler.step(epoch_loss[0])
else:
scheduler.last_epoch = epoch
if args.adv:
adv_scheduler.last_epoch = epoch
if args.rank == 0:
print(end='', flush=True)
args.logger.close()
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
if is_best:
if is_best and epoch >= args.adv_delay:
min_loss = epoch_loss
state = {