Refactor runtime default argument setting

This commit is contained in:
Yin Li 2020-05-16 16:15:42 -04:00
parent 9b456d6b1a
commit 4cc2fd51eb
3 changed files with 43 additions and 27 deletions

View file

@ -4,8 +4,11 @@ from .train import ckpt_link
def get_args():
"""Parse arguments and set runtime defaults.
"""
parser = argparse.ArgumentParser(
description='Transform field(s) to field(s)')
subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
train_parser = subparsers.add_parser(
'train',
@ -21,6 +24,11 @@ def get_args():
args = parser.parse_args()
if args.mode == 'train':
set_train_args(args)
elif args.mode == 'test':
set_test_args(args)
return args
@ -52,9 +60,10 @@ def add_common_args(parser):
parser.add_argument('--batches', default=1, type=int,
help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', default=0, type=int,
parser.add_argument('--loader-workers', type=int,
help='number of data loading workers, per GPU in training or '
'in total in testing')
'in total in testing. '
'Default is the batch size or 0 for batch size 1')
parser.add_argument('--cache', action='store_true',
help='enable caching in field datasets')
@ -116,9 +125,8 @@ def add_train_args(parser):
help='adversary weight decay')
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
help='Enable ReduceLROnPlateau learning rate scheduler')
parser.add_argument('--init-weight-scale', type=float,
help='weight initialization scale, default is 0.02 with adversary '
'and the pytorch default without it')
parser.add_argument('--init-weight-std', type=float,
help='weight initialization std')
parser.add_argument('--epochs', default=128, type=int,
help='total number of epochs to run')
parser.add_argument('--seed', default=42, type=int,
@ -153,3 +161,29 @@ def str_list(s):
# elif len(t) != 6:
# raise ValueError('size must be int or 6-tuple')
# return t
def set_common_args(args):
if args.loader_workers is None:
args.loader_workers = 0
if args.batches > 1:
args.loader_workers = args.batches
def set_train_args(args):
set_common_args(args)
args.val = args.val_in_patterns is not None and \
args.val_tgt_patterns is not None
args.adv = args.adv_model is not None
if args.adv:
if args.adv_lr is None:
args.adv_lr = args.lr
if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay
def set_test_args(args):
set_common_args(args)

View file

@ -45,8 +45,6 @@ def node_worker(args):
def gpu_worker(local_rank, node, args):
set_runtime_default_args(args)
device = torch.device('cuda', local_rank)
torch.cuda.device(device)
@ -173,16 +171,16 @@ def gpu_worker(local_rank, node, args):
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_scale)
m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_scale)
m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0)
if args.init_weight_scale is not None:
if args.init_weight_std is not None:
model.apply(init_weights)
if args.adv:
@ -503,22 +501,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
return epoch_loss
def set_runtime_default_args(args):
args.val = args.val_in_patterns is not None and \
args.val_tgt_patterns is not None
args.adv = args.adv_model is not None
if args.adv:
if args.adv_lr is None:
args.adv_lr = args.lr
if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay
if args.init_weight_scale is None:
args.init_weight_scale = 0.02
def dist_init(rank, args):
dist_file = 'dist_addr'

View file

@ -8,7 +8,7 @@ setup(
author='Yin Li et al.',
author_email='eelregit@gmail.com',
packages=find_packages(),
python_requires='>=3',
python_requires='>=3.2',
install_requires=[
'torch',
'numpy',