Add argument to disable ReduceLROnPlateau by default

This commit is contained in:
Yin Li 2020-02-09 20:15:23 -05:00
parent a2e3a9dc64
commit 2aed00d97e
2 changed files with 11 additions and 8 deletions

View File

@@ -90,6 +90,8 @@ def add_train_args(parser):
                         help='initial adversary learning rate')
     parser.add_argument('--adv-weight-decay', type=float,
                         help='adversary weight decay')
+    parser.add_argument('--reduce-lr-on-plateau', action='store_true',
+                        help='Enable ReduceLROnPlateau learning rate scheduler')
     parser.add_argument('--epochs', default=128, type=int,
                         help='total number of epochs to run')
     parser.add_argument('--seed', default=42, type=int,

View File

@@ -210,14 +210,15 @@ def gpu_worker(local_rank, args):
                     args)
         epoch_loss = val_loss

-        if epoch >= args.adv_delay:
-            scheduler.step(epoch_loss[0])
-            if args.adv:
-                adv_scheduler.step(epoch_loss[0])
-        else:
-            scheduler.last_epoch = epoch
-            if args.adv:
-                adv_scheduler.last_epoch = epoch
+        if args.reduce_lr_on_plateau:
+            if epoch >= args.adv_delay:
+                scheduler.step(epoch_loss[0])
+                if args.adv:
+                    adv_scheduler.step(epoch_loss[0])
+            else:
+                scheduler.last_epoch = epoch
+                if args.adv:
+                    adv_scheduler.last_epoch = epoch

         if args.rank == 0:
             print(end='', flush=True)