map2map/map2map/args.py

import os
import argparse
import warnings

from .train import ckpt_link


def get_args():
    """Parse arguments and set runtime defaults.
    """
    parser = argparse.ArgumentParser(
        description='Transform field(s) to field(s)')

    subparsers = parser.add_subparsers(title='modes', dest='mode', required=True)
    train_parser = subparsers.add_parser(
        'train',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    test_parser = subparsers.add_parser(
        'test',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    add_train_args(train_parser)
    add_test_args(test_parser)

    args = parser.parse_args()

    if args.mode == 'train':
        set_train_args(args)
    elif args.mode == 'test':
        set_test_args(args)

    return args
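# Hedged sketch of how a caller is expected to consume get_args(); the
# train/test entry points named below are assumptions about the rest of the
# package, not defined in this module:
#
#     args = get_args()
#     if args.mode == 'train':
#         train.node_worker(args)   # hypothetical training entry point
#     elif args.mode == 'test':
#         test.test(args)           # hypothetical testing entry point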


def add_common_args(parser):
    parser.add_argument('--in-norms', type=str_list, help='comma-sep. list '
            'of input normalization functions from .data.norms')
    parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
            'of target normalization functions from .data.norms')
    parser.add_argument('--crop', type=int,
            help='size to crop the input and target data. Default is the '
            'field size')
    parser.add_argument('--crop-start', type=int,
            help='starting point of the first crop. Default is the origin')
    parser.add_argument('--crop-stop', type=int,
            help='stopping point of the last crop. Default is the opposite '
            'corner to the origin')
    parser.add_argument('--crop-step', type=int,
            help='spacing between crops. Default is the crop size')
    parser.add_argument('--pad', default=0, type=int,
            help='size to pad the input data beyond the crop size, assuming '
            'periodic boundary condition')
    parser.add_argument('--scale-factor', default=1, type=int,
            help='upsampling factor for super-resolution, in which case '
            'crop and pad are given in the input resolution')
    parser.add_argument('--model', type=str, required=True,
            help='model from .models')
    parser.add_argument('--criterion', default='MSELoss', type=str,
            help='model criterion from torch.nn')
    parser.add_argument('--load-state', default=ckpt_link, type=str,
            help='path to load the states of model, optimizer, rng, etc. '
            'Default is the checkpoint. '
            'Start from scratch in case of empty string or missing checkpoint')
    # store_false with dest: passing this flag sets args.load_state_strict
    # to False, allowing incompatible keys when loading model states
    parser.add_argument('--load-state-non-strict', action='store_false',
            help='allow incompatible keys when loading model states',
            dest='load_state_strict')
    parser.add_argument('--batches', type=int, required=True,
            help='mini-batch size, per GPU in training or in total in testing')
    parser.add_argument('--loader-workers', default=-2, type=int,
            help='number of subprocesses per data loader. '
            '0 to disable multiprocessing; '
            'negative number to multiply by the batch size')
    parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
            help='directory of custom code defining callbacks for models, '
            'norms, criteria, and optimizers. Disabled if not set. '
            'This is appended to the default locations, '
            'thus has the lowest priority.')
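# Illustrative sketch of a --callback-at directory; the file name, class name,
# and the exact lookup mechanism are assumptions, not defined in this module.
# The intent is that e.g. `--criterion MyLoss` can fall back to custom code
# after the default locations (torch.nn here) have been searched:
#
#     # callbacks/my_loss.py
#     import torch.nn as nn
#
#     class MyLoss(nn.Module):
#         def forward(self, output, target):
#             return nn.functional.smooth_l1_loss(output, target)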


def add_train_args(parser):
    add_common_args(parser)

    parser.add_argument('--train-in-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for training input data')
    parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for training target data')
    parser.add_argument('--val-in-patterns', type=str_list,
            help='comma-sep. list of glob patterns for validation input data')
    parser.add_argument('--val-tgt-patterns', type=str_list,
            help='comma-sep. list of glob patterns for validation target data')
    parser.add_argument('--augment', action='store_true',
            help='enable data augmentation of axis flipping and permutation')
    parser.add_argument('--aug-shift', type=int,
            help='data augmentation by shifting [0, aug_shift) pixels, '
            'useful for models that treat neighboring pixels differently, '
            'e.g. with strided convolutions')
    parser.add_argument('--aug-add', type=float,
            help='additive data augmentation, (normal) std, '
            'same factor for all fields')
    parser.add_argument('--aug-mul', type=float,
            help='multiplicative data augmentation, (log-normal) std, '
            'same factor for all fields')
    parser.add_argument('--optimizer', default='Adam', type=str,
            help='optimizer from torch.optim')
    parser.add_argument('--lr', type=float, required=True,
            help='initial learning rate')
    # parser.add_argument('--momentum', default=0.9, type=float,
    #         help='momentum')
    parser.add_argument('--weight-decay', default=0, type=float,
            help='weight decay')
    parser.add_argument('--reduce-lr-on-plateau', action='store_true',
            help='Enable ReduceLROnPlateau learning rate scheduler')
    parser.add_argument('--init-weight-std', type=float,
            help='weight initialization std')
    parser.add_argument('--epochs', default=128, type=int,
            help='total number of epochs to run')
    parser.add_argument('--seed', default=42, type=int,
            help='seed for initializing training')

    parser.add_argument('--dist-backend', default='nccl', type=str,
            choices=['gloo', 'nccl'], help='distributed backend')
    parser.add_argument('--log-interval', default=100, type=int,
            help='interval between logging training loss')
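# Example training invocation (illustrative only: the entry-point script,
# data paths, and model name are assumptions, not taken from this file; the
# flags and their required/default status come from the parser above):
#
#     python m2m.py train \
#         --train-in-patterns "train/in_*.npy" \
#         --train-tgt-patterns "train/tgt_*.npy" \
#         --model UNet --batches 1 --lr 1e-4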


def add_test_args(parser):
    add_common_args(parser)

    parser.add_argument('--test-in-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for test input data')
    parser.add_argument('--test-tgt-patterns', type=str_list, required=True,
            help='comma-sep. list of glob patterns for test target data')
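# Example testing invocation (same caveats as the training example above:
# script name, paths, and model name are assumptions). --load-state defaults
# to the checkpoint link, so a previously saved state can be picked up
# automatically, assuming the checkpoint link exists in the working directory:
#
#     python m2m.py test \
#         --test-in-patterns "test/in_*.npy" \
#         --test-tgt-patterns "test/tgt_*.npy" \
#         --model UNet --batches 1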


def str_list(s):
    """Parse a comma-separated string into a list of strings."""
    return s.split(',')


#def int_tuple(t):
#    t = t.split(',')
#    t = tuple(int(i) for i in t)
#    if len(t) == 1:
#        t = t[0]
#    elif len(t) != 6:
#        raise ValueError('size must be int or 6-tuple')
#    return t


def set_common_args(args):
    if args.loader_workers < 0:
        args.loader_workers *= - args.batches
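# Worked example of the negative --loader-workers convention above: with the
# default --loader-workers -2 and --batches 4, the number of data loader
# workers becomes (-2) * (-4) = 8, i.e. two workers per sample in a mini-batch.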


def set_train_args(args):
    set_common_args(args)

    # run validation only when both validation input and target patterns are given
    args.val = args.val_in_patterns is not None and \
            args.val_tgt_patterns is not None


def set_test_args(args):
    set_common_args(args)
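

if __name__ == '__main__':
    # Minimal sketch for inspecting the parsed arguments from the command line,
    # e.g. `python -m map2map.args train --model UNet ...`; this guard is an
    # illustrative addition, not part of the original interface.
    from pprint import pprint
    pprint(vars(get_args()))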