Add optimizer and scheduler kwargs via json.loads

Yin Li 2020-07-27 12:40:19 -07:00
parent 8ce13e67f6
commit 2336d83f2d
2 changed files with 13 additions and 13 deletions


@@ -1,5 +1,6 @@
 import os
 import argparse
+import json
 import warnings
 
 from .train import ckpt_link
@@ -110,12 +111,14 @@ def add_train_args(parser):
                         help='optimizer from torch.optim')
     parser.add_argument('--lr', type=float, required=True,
                         help='initial learning rate')
-    # parser.add_argument('--momentum', default=0.9, type=float,
-    #                     help='momentum')
-    parser.add_argument('--weight-decay', default=0, type=float,
-                        help='weight decay')
+    parser.add_argument('--optimizer-args', type=json.loads,
+                        help='optimizer arguments in addition to the learning rate, '
+                             'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
     parser.add_argument('--reduce-lr-on-plateau', action='store_true',
                         help='Enable ReduceLROnPlateau learning rate scheduler')
+    parser.add_argument('--scheduler-args', default='{"verbose": true}',
+                        type=json.loads,
+                        help='arguments for the ReduceLROnPlateau scheduler')
     parser.add_argument('--init-weight-std', type=float,
                         help='weight initialization std')
     parser.add_argument('--epochs', default=128, type=int,
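
A side note on the hunk above: argparse feeds the raw command-line string through the callable given as type, and it also runs a string-valued default through that same callable, which is why the --scheduler-args default is written as a JSON string. A minimal standalone sketch of that behavior (illustration only, not part of this commit):

import argparse
import json

parser = argparse.ArgumentParser()
# argparse applies the `type` callable to the raw CLI string, so
# json.loads turns '{"betas": [0.5, 0.9]}' directly into a dict
parser.add_argument('--optimizer-args', type=json.loads)
# string defaults are also passed through `type`, so this default
# parses to the dict {'verbose': True}
parser.add_argument('--scheduler-args', default='{"verbose": true}',
                    type=json.loads)

args = parser.parse_args(['--optimizer-args', '{"betas": [0.5, 0.9]}'])
print(args.optimizer_args)  # {'betas': [0.5, 0.9]}
print(args.scheduler_args)  # {'verbose': True}

The second file's hunk below then consumes the parsed dicts in the gpu_worker training setup.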


@@ -135,20 +135,17 @@ def gpu_worker(local_rank, node, args):
     lag_optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
-        #momentum=args.momentum,
-        betas=(0.9, 0.999),
-        weight_decay=args.weight_decay,
+        **args.optimizer_args,
     )
     eul_optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
-        betas=(0.9, 0.999),
-        weight_decay=args.weight_decay,
+        **args.optimizer_args,
     )
-    lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
-            factor=0.1, patience=10, verbose=True)
-    eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
-            factor=0.1, patience=10, verbose=True)
+    lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        lag_optimizer, **args.scheduler_args)
+    eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        eul_optimizer, **args.scheduler_args)
 
     if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
             or not args.load_state):
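
For completeness, a standalone sketch (hypothetical names, not the repo's code) of what the splatting in gpu_worker amounts to once argparse has produced the dicts; Adam is used for concreteness, and the factor/patience values mirror the hardcoded ones this commit removes. One caveat: as of this commit --optimizer-args has no default, so omitting the flag leaves args.optimizer_args as None, and optimizer(..., **None) raises TypeError; a default of '{}' would make the splat a no-op.

import json

import torch
from torch import optim

model = torch.nn.Linear(4, 4)

# what argparse would hand over for, e.g.:
#   --lr 0.001 --optimizer-args '{"betas": [0.5, 0.9]}' \
#   --scheduler-args '{"factor": 0.1, "patience": 10}'
optimizer_args = json.loads('{"betas": [0.5, 0.9]}')
scheduler_args = json.loads('{"factor": 0.1, "patience": 10}')

optimizer = optim.Adam(
    model.parameters(),
    lr=0.001,
    **optimizer_args,  # Adam unpacks the JSON list as its betas pair
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, **scheduler_args)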