Add optimizer and scheduler kwargs via json.loads

This commit is contained in:
Yin Li 2020-07-27 12:40:19 -07:00
parent 8ce13e67f6
commit 2336d83f2d
2 changed files with 13 additions and 13 deletions

View File

@ -1,5 +1,6 @@
import os
import argparse
import json
import warnings
from .train import ckpt_link
@ -110,12 +111,14 @@ def add_train_args(parser):
help='optimizer from torch.optim')
parser.add_argument('--lr', type=float, required=True,
help='initial learning rate')
# parser.add_argument('--momentum', default=0.9, type=float,
# help='momentum')
parser.add_argument('--weight-decay', default=0, type=float,
help='weight decay')
parser.add_argument('--optimizer-args', type=json.loads,
help='optimizer arguments in addition to the learning rate, '
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
help='Enable ReduceLROnPlateau learning rate scheduler')
parser.add_argument('--scheduler-args', default='{"verbose": true}',
type=json.loads,
help='arguments for the ReduceLROnPlateau scheduler')
parser.add_argument('--init-weight-std', type=float,
help='weight initialization std')
parser.add_argument('--epochs', default=128, type=int,

View File

@ -135,20 +135,17 @@ def gpu_worker(local_rank, node, args):
lag_optimizer = optimizer(
model.parameters(),
lr=args.lr,
#momentum=args.momentum,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
**args.optimizer_args,
)
eul_optimizer = optimizer(
model.parameters(),
lr=args.lr,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
**args.optimizer_args,
)
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
factor=0.1, patience=10, verbose=True)
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
factor=0.1, patience=10, verbose=True)
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
lag_optimizer, **args.scheduler_args)
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
eul_optimizer, **args.scheduler_args)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):