Refactor runtime default argument setting
This commit is contained in:
parent
9b456d6b1a
commit
4cc2fd51eb
3 changed files with 43 additions and 27 deletions
|
@ -4,8 +4,11 @@ from .train import ckpt_link
|
|||
|
||||
|
||||
def get_args():
|
||||
"""Parse arguments and set runtime defaults.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Transform field(s) to field(s)')
|
||||
|
||||
subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
|
||||
train_parser = subparsers.add_parser(
|
||||
'train',
|
||||
|
@ -21,6 +24,11 @@ def get_args():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == 'train':
|
||||
set_train_args(args)
|
||||
elif args.mode == 'test':
|
||||
set_test_args(args)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
|
@ -52,9 +60,10 @@ def add_common_args(parser):
|
|||
|
||||
parser.add_argument('--batches', default=1, type=int,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', default=0, type=int,
|
||||
parser.add_argument('--loader-workers', type=int,
|
||||
help='number of data loading workers, per GPU in training or '
|
||||
'in total in testing')
|
||||
'in total in testing. '
|
||||
'Default is the batch size or 0 for batch size 1')
|
||||
|
||||
parser.add_argument('--cache', action='store_true',
|
||||
help='enable caching in field datasets')
|
||||
|
@ -116,9 +125,8 @@ def add_train_args(parser):
|
|||
help='adversary weight decay')
|
||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||
parser.add_argument('--init-weight-scale', type=float,
|
||||
help='weight initialization scale, default is 0.02 with adversary '
|
||||
'and the pytorch default without it')
|
||||
parser.add_argument('--init-weight-std', type=float,
|
||||
help='weight initialization std')
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--seed', default=42, type=int,
|
||||
|
@ -153,3 +161,29 @@ def str_list(s):
|
|||
# elif len(t) != 6:
|
||||
# raise ValueError('size must be int or 6-tuple')
|
||||
# return t
|
||||
|
||||
|
||||
def set_common_args(args):
|
||||
if args.loader_workers is None:
|
||||
args.loader_workers = 0
|
||||
if args.batches > 1:
|
||||
args.loader_workers = args.batches
|
||||
|
||||
|
||||
def set_train_args(args):
|
||||
set_common_args(args)
|
||||
|
||||
args.val = args.val_in_patterns is not None and \
|
||||
args.val_tgt_patterns is not None
|
||||
|
||||
args.adv = args.adv_model is not None
|
||||
|
||||
if args.adv:
|
||||
if args.adv_lr is None:
|
||||
args.adv_lr = args.lr
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
|
||||
def set_test_args(args):
|
||||
set_common_args(args)
|
||||
|
|
|
@ -45,8 +45,6 @@ def node_worker(args):
|
|||
|
||||
|
||||
def gpu_worker(local_rank, node, args):
|
||||
set_runtime_default_args(args)
|
||||
|
||||
device = torch.device('cuda', local_rank)
|
||||
torch.cuda.device(device)
|
||||
|
||||
|
@ -173,16 +171,16 @@ def gpu_worker(local_rank, node, args):
|
|||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_scale)
|
||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_scale)
|
||||
m.weight.data.normal_(1.0, args.init_weight_std)
|
||||
m.bias.data.fill_(0)
|
||||
|
||||
if args.init_weight_scale is not None:
|
||||
if args.init_weight_std is not None:
|
||||
model.apply(init_weights)
|
||||
|
||||
if args.adv:
|
||||
|
@ -503,22 +501,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||
return epoch_loss
|
||||
|
||||
|
||||
def set_runtime_default_args(args):
|
||||
args.val = args.val_in_patterns is not None and \
|
||||
args.val_tgt_patterns is not None
|
||||
|
||||
args.adv = args.adv_model is not None
|
||||
|
||||
if args.adv:
|
||||
if args.adv_lr is None:
|
||||
args.adv_lr = args.lr
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
if args.init_weight_scale is None:
|
||||
args.init_weight_scale = 0.02
|
||||
|
||||
|
||||
def dist_init(rank, args):
|
||||
dist_file = 'dist_addr'
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -8,7 +8,7 @@ setup(
|
|||
author='Yin Li et al.',
|
||||
author_email='eelregit@gmail.com',
|
||||
packages=find_packages(),
|
||||
python_requires='>=3',
|
||||
python_requires='>=3.2',
|
||||
install_requires=[
|
||||
'torch',
|
||||
'numpy',
|
||||
|
|
Loading…
Reference in a new issue