From 4cc2fd51ebe300681870e0f52f8fa272766f8453 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Sat, 16 May 2020 16:15:42 -0400 Subject: [PATCH] Refactor runtime default argument setting --- map2map/args.py | 44 +++++++++++++++++++++++++++++++++++++++----- map2map/train.py | 24 +++--------------------- setup.py | 2 +- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index 6fba448..a2e83de 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -4,8 +4,11 @@ from .train import ckpt_link def get_args(): + """Parse arguments and set runtime defaults. + """ parser = argparse.ArgumentParser( description='Transform field(s) to field(s)') + subparsers = parser.add_subparsers(title='modes', dest='mode', required=True) train_parser = subparsers.add_parser( 'train', @@ -21,6 +24,11 @@ def get_args(): args = parser.parse_args() + if args.mode == 'train': + set_train_args(args) + elif args.mode == 'test': + set_test_args(args) + return args @@ -52,9 +60,10 @@ def add_common_args(parser): parser.add_argument('--batches', default=1, type=int, help='mini-batch size, per GPU in training or in total in testing') - parser.add_argument('--loader-workers', default=0, type=int, + parser.add_argument('--loader-workers', type=int, help='number of data loading workers, per GPU in training or ' - 'in total in testing') + 'in total in testing. 
' + 'Default is the batch size or 0 for batch size 1') parser.add_argument('--cache', action='store_true', help='enable caching in field datasets') @@ -116,9 +125,8 @@ def add_train_args(parser): help='adversary weight decay') parser.add_argument('--reduce-lr-on-plateau', action='store_true', help='Enable ReduceLROnPlateau learning rate scheduler') - parser.add_argument('--init-weight-scale', type=float, - help='weight initialization scale, default is 0.02 with adversary ' - 'and the pytorch default without it') + parser.add_argument('--init-weight-std', type=float, + help='weight initialization std') parser.add_argument('--epochs', default=128, type=int, help='total number of epochs to run') parser.add_argument('--seed', default=42, type=int, @@ -153,3 +161,29 @@ def str_list(s): # elif len(t) != 6: # raise ValueError('size must be int or 6-tuple') # return t + + +def set_common_args(args): + if args.loader_workers is None: + args.loader_workers = 0 + if args.batches > 1: + args.loader_workers = args.batches + + +def set_train_args(args): + set_common_args(args) + + args.val = args.val_in_patterns is not None and \ + args.val_tgt_patterns is not None + + args.adv = args.adv_model is not None + + if args.adv: + if args.adv_lr is None: + args.adv_lr = args.lr + if args.adv_weight_decay is None: + args.adv_weight_decay = args.weight_decay + + +def set_test_args(args): + set_common_args(args) diff --git a/map2map/train.py b/map2map/train.py index 7a5d99e..fe21b73 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -45,8 +45,6 @@ def node_worker(args): def gpu_worker(local_rank, node, args): - set_runtime_default_args(args) - device = torch.device('cuda', local_rank) torch.cuda.device(device) @@ -173,16 +171,16 @@ def gpu_worker(local_rank, node, args): def init_weights(m): if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): - m.weight.data.normal_(0.0, args.init_weight_scale) + 
m.weight.data.normal_(0.0, args.init_weight_std) elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): if m.affine: # NOTE: dispersion from DCGAN, why? - m.weight.data.normal_(1.0, args.init_weight_std) m.bias.data.fill_(0) - if args.init_weight_scale is not None: + if args.init_weight_std is not None: model.apply(init_weights) if args.adv: @@ -503,22 +501,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, return epoch_loss -def set_runtime_default_args(args): - args.val = args.val_in_patterns is not None and \ - args.val_tgt_patterns is not None - - args.adv = args.adv_model is not None - - if args.adv: - if args.adv_lr is None: - args.adv_lr = args.lr - if args.adv_weight_decay is None: - args.adv_weight_decay = args.weight_decay - - if args.init_weight_scale is None: - args.init_weight_scale = 0.02 - - def dist_init(rank, args): dist_file = 'dist_addr' diff --git a/setup.py b/setup.py index 409b1e7..9f4e5b5 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ setup( author='Yin Li et al.', author_email='eelregit@gmail.com', packages=find_packages(), - python_requires='>=3', + python_requires='>=3.7', install_requires=[ 'torch', 'numpy',