Add optimizer and scheduler kwargs via json.loads
This commit is contained in:
parent
8ce13e67f6
commit
2336d83f2d
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .train import ckpt_link
|
from .train import ckpt_link
|
||||||
@ -110,12 +111,14 @@ def add_train_args(parser):
|
|||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
parser.add_argument('--lr', type=float, required=True,
|
parser.add_argument('--lr', type=float, required=True,
|
||||||
help='initial learning rate')
|
help='initial learning rate')
|
||||||
# parser.add_argument('--momentum', default=0.9, type=float,
|
parser.add_argument('--optimizer-args', type=json.loads,
|
||||||
# help='momentum')
|
help='optimizer arguments in addition to the learning rate, '
|
||||||
parser.add_argument('--weight-decay', default=0, type=float,
|
'e.g. --optimizer-args \'{"betas": [0.5, 0.9]}\'')
|
||||||
help='weight decay')
|
|
||||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||||
|
parser.add_argument('--scheduler-args', default='{"verbose": true}',
|
||||||
|
type=json.loads,
|
||||||
|
help='arguments for the ReduceLROnPlateau scheduler')
|
||||||
parser.add_argument('--init-weight-std', type=float,
|
parser.add_argument('--init-weight-std', type=float,
|
||||||
help='weight initialization std')
|
help='weight initialization std')
|
||||||
parser.add_argument('--epochs', default=128, type=int,
|
parser.add_argument('--epochs', default=128, type=int,
|
||||||
|
@ -135,20 +135,17 @@ def gpu_worker(local_rank, node, args):
|
|||||||
lag_optimizer = optimizer(
|
lag_optimizer = optimizer(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
#momentum=args.momentum,
|
**args.optimizer_args,
|
||||||
betas=(0.9, 0.999),
|
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
)
|
)
|
||||||
eul_optimizer = optimizer(
|
eul_optimizer = optimizer(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
betas=(0.9, 0.999),
|
**args.optimizer_args,
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
)
|
)
|
||||||
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(lag_optimizer,
|
lag_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
factor=0.1, patience=10, verbose=True)
|
lag_optimizer, **args.scheduler_args)
|
||||||
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(eul_optimizer,
|
eul_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
factor=0.1, patience=10, verbose=True)
|
eul_optimizer, **args.scheduler_args)
|
||||||
|
|
||||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||||
or not args.load_state):
|
or not args.load_state):
|
||||||
|
Loading…
Reference in New Issue
Block a user