Refactor runtime default argument setting

parent 9b456d6b1a
commit 4cc2fd51eb
args.py

@@ -4,8 +4,11 @@ from .train import ckpt_link
 
 
 def get_args():
+    """Parse arguments and set runtime defaults.
+
+    """
     parser = argparse.ArgumentParser(
         description='Transform field(s) to field(s)')
 
     subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
     train_parser = subparsers.add_parser(
         'train',
@@ -21,6 +24,11 @@ def get_args():
 
     args = parser.parse_args()
 
+    if args.mode == 'train':
+        set_train_args(args)
+    elif args.mode == 'test':
+        set_test_args(args)
+
     return args
 
 
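The dispatch above keeps all runtime-default logic behind get_args(), so callers no longer need a separate post-processing pass. A minimal, self-contained sketch of the same pattern (the flag names mirror the real ones, but the set_*_args bodies here are stand-ins; required=True on add_subparsers needs Python 3.7+):

import argparse


def set_train_args(args):
    args.adv = False  # stand-in for the real derived defaults


def set_test_args(args):
    pass


def get_args():
    parser = argparse.ArgumentParser(description='sketch')
    subparsers = parser.add_subparsers(title='modes', dest='mode',
            required=True)
    train_parser = subparsers.add_parser('train')
    train_parser.add_argument('--batches', default=1, type=int)
    subparsers.add_parser('test')

    args = parser.parse_args()

    # resolve defaults right after parsing, so every caller
    # sees fully populated args
    if args.mode == 'train':
        set_train_args(args)
    elif args.mode == 'test':
        set_test_args(args)

    return args


if __name__ == '__main__':
    print(get_args())  # e.g. `python sketch.py train --batches 4`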
@@ -52,9 +60,10 @@ def add_common_args(parser):
 
     parser.add_argument('--batches', default=1, type=int,
             help='mini-batch size, per GPU in training or in total in testing')
-    parser.add_argument('--loader-workers', default=0, type=int,
+    parser.add_argument('--loader-workers', type=int,
             help='number of data loading workers, per GPU in training or '
-            'in total in testing')
+            'in total in testing. '
+            'Default is the batch size or 0 for batch size 1')
 
     parser.add_argument('--cache', action='store_true',
             help='enable caching in field datasets')
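For illustration: with --batches 4 and --loader-workers left unset, the worker count resolves to 4; with --batches 1 it resolves to 0. The resolution itself lives in set_common_args, added below.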
@@ -116,9 +125,8 @@ def add_train_args(parser):
             help='adversary weight decay')
     parser.add_argument('--reduce-lr-on-plateau', action='store_true',
             help='Enable ReduceLROnPlateau learning rate scheduler')
-    parser.add_argument('--init-weight-scale', type=float,
-            help='weight initialization scale, default is 0.02 with adversary '
-            'and the pytorch default without it')
+    parser.add_argument('--init-weight-std', type=float,
+            help='weight initialization std')
     parser.add_argument('--epochs', default=128, type=int,
             help='total number of epochs to run')
     parser.add_argument('--seed', default=42, type=int,
@@ -153,3 +161,29 @@ def str_list(s):
 #     elif len(t) != 6:
 #         raise ValueError('size must be int or 6-tuple')
 #     return t
+
+
+def set_common_args(args):
+    if args.loader_workers is None:
+        args.loader_workers = 0
+        if args.batches > 1:
+            args.loader_workers = args.batches
+
+
+def set_train_args(args):
+    set_common_args(args)
+
+    args.val = args.val_in_patterns is not None and \
+            args.val_tgt_patterns is not None
+
+    args.adv = args.adv_model is not None
+
+    if args.adv:
+        if args.adv_lr is None:
+            args.adv_lr = args.lr
+        if args.adv_weight_decay is None:
+            args.adv_weight_decay = args.weight_decay
+
+
+def set_test_args(args):
+    set_common_args(args)
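A hedged usage sketch of the new helpers, driving set_train_args with a bare Namespace instead of parsed arguments (the attribute names match the real flags; the values are made up):

from argparse import Namespace

args = Namespace(loader_workers=None, batches=4,
        val_in_patterns=None, val_tgt_patterns=None,
        adv_model='patchgan', adv_lr=None, lr=1e-3,
        adv_weight_decay=None, weight_decay=1e-4)

set_train_args(args)

assert args.loader_workers == 4          # follows the batch size
assert args.val is False                 # no validation patterns given
assert args.adv and args.adv_lr == 1e-3  # adversary falls back to --lr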
train.py

@@ -45,8 +45,6 @@ def node_worker(args):
 
 
 def gpu_worker(local_rank, node, args):
-    set_runtime_default_args(args)
-
     device = torch.device('cuda', local_rank)
     torch.cuda.device(device)
 
@@ -173,16 +171,16 @@ def gpu_worker(local_rank, node, args):
     def init_weights(m):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                 nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
-            m.weight.data.normal_(0.0, args.init_weight_scale)
+            m.weight.data.normal_(0.0, args.init_weight_std)
         elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                 nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
                 nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
             if m.affine:
                 # NOTE: dispersion from DCGAN, why?
-                m.weight.data.normal_(1.0, args.init_weight_scale)
+                m.weight.data.normal_(1.0, args.init_weight_std)
                 m.bias.data.fill_(0)
 
-    if args.init_weight_scale is not None:
+    if args.init_weight_std is not None:
         model.apply(init_weights)
 
     if args.adv:
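With the rename, --init-weight-std has no fallback value anymore (the old code filled in 0.02 when training with an adversary), so initialization runs only when the flag is given. A standalone sketch of the same std-based init, with args.init_weight_std replaced by a plain variable:

import torch.nn as nn

init_weight_std = 0.02  # e.g. the DCGAN value that used to be the fallback


def init_weights(m):
    if isinstance(m, (nn.Linear, nn.Conv3d, nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, init_weight_std)


model = nn.Sequential(nn.Conv3d(1, 8, 3), nn.ReLU(), nn.Conv3d(8, 1, 3))

if init_weight_std is not None:
    model.apply(init_weights)  # apply() visits every submodule recursively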
@@ -503,22 +501,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
     return epoch_loss
 
 
-def set_runtime_default_args(args):
-    args.val = args.val_in_patterns is not None and \
-            args.val_tgt_patterns is not None
-
-    args.adv = args.adv_model is not None
-
-    if args.adv:
-        if args.adv_lr is None:
-            args.adv_lr = args.lr
-        if args.adv_weight_decay is None:
-            args.adv_weight_decay = args.weight_decay
-
-        if args.init_weight_scale is None:
-            args.init_weight_scale = 0.02
-
-
 def dist_init(rank, args):
     dist_file = 'dist_addr'
 
setup.py

@@ -8,7 +8,7 @@ setup(
     author='Yin Li et al.',
     author_email='eelregit@gmail.com',
     packages=find_packages(),
-    python_requires='>=3',
+    python_requires='>=3.2',
     install_requires=[
         'torch',
         'numpy',