Add delay to scheduler by adv_delay
This commit is contained in:
parent
0a2fc9a9e9
commit
7f6578c63e
@ -72,7 +72,8 @@ def add_train_args(parser):
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-delay', default=0, type=int,
|
||||
help='epoch before updating the generator with adversarial loss')
|
||||
help='epoch before updating the generator with adversarial loss, '
|
||||
'and the learning rate with scheduler')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
|
@ -187,16 +187,21 @@ def gpu_worker(local_rank, args):
|
||||
args)
|
||||
epoch_loss = val_loss
|
||||
|
||||
if epoch >= args.adv_delay:
|
||||
scheduler.step(epoch_loss[0])
|
||||
if args.adv:
|
||||
adv_scheduler.step(epoch_loss[0])
|
||||
else:
|
||||
scheduler.last_epoch = epoch
|
||||
if args.adv:
|
||||
adv_scheduler.last_epoch = epoch
|
||||
|
||||
if args.rank == 0:
|
||||
print(end='', flush=True)
|
||||
args.logger.close()
|
||||
|
||||
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if is_best:
|
||||
if is_best and epoch >= args.adv_delay:
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
|
Loading…
Reference in New Issue
Block a user