Add optimizer and scheduler kwargs via json.loads
This commit is contained in:
parent
8ce13e67f6
commit
2336d83f2d
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import warnings
|
||||
|
||||
from .train import ckpt_link
|
||||
@ -110,12 +111,14 @@ def add_train_args(parser):
|
||||
help='optimizer from torch.optim')
|
||||
parser.add_argument('--lr', type=float, required=True,
|
||||
help='initial learning rate')
|
||||
# parser.add_argument('--momentum', default=0.9, type=float,
|
||||
# help='momentum')
|
||||
parser.add_argument('--weight-decay', default=0, type=float,
|
||||
help='weight decay')
|
||||
parser.add_argument('--optimizer-args', type=json.loads,
|
||||
help='optimizer arguments in addition to the learning rate, '
|
||||
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
|
||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||
parser.add_argument('--scheduler-args', default='{"verbose": true}',
|
||||
type=json.loads,
|
||||
help='arguments for the ReduceLROnPlateau scheduler')
|
||||
parser.add_argument('--init-weight-std', type=float,
|
||||
help='weight initialization std')
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
|
@ -135,20 +135,17 @@ def gpu_worker(local_rank, node, args):
|
||||
lag_optimizer = optimizer(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
#momentum=args.momentum,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=args.weight_decay,
|
||||
**args.optimizer_args,
|
||||
)
|
||||
eul_optimizer = optimizer(
|
||||
model.parameters(),
|
||||
lr=args.lr,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=args.weight_decay,
|
||||
**args.optimizer_args,
|
||||
)
|
||||
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
|
||||
factor=0.1, patience=10, verbose=True)
|
||||
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
|
||||
factor=0.1, patience=10, verbose=True)
|
||||
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
lag_optimizer, **args.scheduler_args)
|
||||
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||
eul_optimizer, **args.scheduler_args)
|
||||
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
|
Loading…
Reference in New Issue
Block a user