From 7f6578c63eb5d5e62f46c19e6e06d4d7262001b6 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 3 Feb 2020 14:44:14 -0500 Subject: [PATCH] Add delay to scheduler by adv_delay --- map2map/args.py | 3 ++- map2map/train.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index 36d073e..dde40da 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -72,7 +72,8 @@ def add_train_args(parser): parser.add_argument('--cgan', action='store_true', help='enable conditional GAN') parser.add_argument('--adv-delay', default=0, type=int, - help='epoch before updating the generator with adversarial loss') + help='epoch before updating the generator with adversarial loss, ' + 'and the learning rate with scheduler') parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer from torch.optim') diff --git a/map2map/train.py b/map2map/train.py index 7705ed4..ba1ae30 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -187,16 +187,21 @@ def gpu_worker(local_rank, args): args) epoch_loss = val_loss - scheduler.step(epoch_loss[0]) - if args.adv: - adv_scheduler.step(epoch_loss[0]) + if epoch >= args.adv_delay: + scheduler.step(epoch_loss[0]) + if args.adv: + adv_scheduler.step(epoch_loss[0]) + else: + scheduler.last_epoch = epoch + if args.adv: + adv_scheduler.last_epoch = epoch if args.rank == 0: print(end='', flush=True) args.logger.close() is_best = min_loss is None or epoch_loss[0] < min_loss[0] - if is_best: + if is_best and epoch >= args.adv_delay: min_loss = epoch_loss state = {