189 lines
7.6 KiB
Python
189 lines
7.6 KiB
Python
import argparse
|
|
|
|
from .train import ckpt_link
|
|
|
|
|
|
def get_args():
    """Build the command-line interface, parse ``sys.argv`` and fill defaults.

    The interface has two mandatory, mutually exclusive sub-commands,
    ``train`` and ``test``. After parsing, the hook matching the chosen
    mode sets runtime defaults that depend on other arguments.

    Returns:
        argparse.Namespace: the parsed and post-processed arguments.
    """
    parser = argparse.ArgumentParser(
        description='Transform field(s) to field(s)')
    modes = parser.add_subparsers(title='modes', dest='mode', required=True)

    # (mode name, option registrar, post-parse default setter)
    mode_table = (
        ('train', add_train_args, set_train_args),
        ('test', add_test_args, set_test_args),
    )

    for mode_name, register_options, _ in mode_table:
        mode_parser = modes.add_parser(
            mode_name,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        register_options(mode_parser)

    args = parser.parse_args()

    for mode_name, _, fill_defaults in mode_table:
        if args.mode == mode_name:
            fill_defaults(args)
            break

    return args
|
|
|
|
|
|
def add_common_args(parser):
    """Register the options shared by the ``train`` and ``test`` modes.

    The options cover data normalization, cropping and padding, the model
    and its criterion, state loading, mini-batch and data-loader sizing,
    and dataset caching.

    Args:
        parser: the (sub)parser on which to register the options.
    """
    add = parser.add_argument

    # data normalization, cropping and resampling
    add('--in-norms', type=str_list,
        help='comma-sep. list '
             'of input normalization functions from .data.norms')
    add('--tgt-norms', type=str_list,
        help='comma-sep. list '
             'of target normalization functions from .data.norms')
    add('--crop', type=int,
        help='size to crop the input and target data')
    add('--pad', type=int, default=0,
        help='size to pad the input data beyond the crop size, assuming '
             'periodic boundary condition')
    add('--scale-factor', type=int, default=1,
        help='input upsampling factor for super-resolution purpose, in '
             'which case crop and pad will be taken at the original resolution')

    # model, criterion and state loading
    add('--model', type=str, required=True,
        help='model from .models')
    add('--criterion', type=str, default='MSELoss',
        help='model criterion from torch.nn')
    add('--load-state', type=str, default=ckpt_link,
        help='path to load the states of model, optimizer, rng, etc. '
             'Default is the checkpoint. '
             'Start from scratch if the checkpoint does not exist')
    # the flag turns strict loading OFF, hence store_false into a
    # positively-named destination
    add('--load-state-non-strict', dest='load_state_strict',
        action='store_false',
        help='allow incompatible keys when loading model states')

    # batching and data loading
    add('--batches', type=int, default=1,
        help='mini-batch size, per GPU in training or in total in testing')
    add('--loader-workers', type=int,
        help='number of data loading workers, per GPU in training or '
             'in total in testing. '
             'Default is the batch size or 0 for batch size 1')

    add('--cache', action='store_true',
        help='enable caching in field datasets')
|
|
|
|
|
|
def add_train_args(parser):
    """Register the options of the ``train`` mode.

    The shared options come from ``add_common_args``. This function adds
    the training and validation data patterns, data augmentation,
    adversarial training, the optimizer and its schedule, and options for
    the distributed run.

    ``--adv-lr`` and ``--adv-weight-decay`` default to ``None`` here.
    ``set_train_args`` fills them from ``--lr`` / ``--weight-decay`` when an
    adversary model is used.

    Args:
        parser: the ``train`` subparser on which to register the options.
    """
    add_common_args(parser)
    add = parser.add_argument

    # training data (required) and validation data (optional)
    add('--train-in-patterns', required=True, type=str_list,
        help='comma-sep. list of glob patterns for training input data')
    add('--train-tgt-patterns', required=True, type=str_list,
        help='comma-sep. list of glob patterns for training target data')
    add('--val-in-patterns', type=str_list,
        help='comma-sep. list of glob patterns for validation input data')
    add('--val-tgt-patterns', type=str_list,
        help='comma-sep. list of glob patterns for validation target data')

    # data augmentation
    add('--augment', action='store_true',
        help='enable data augmentation of axis flipping and permutation')
    add('--aug-add', type=float,
        help='additive data augmentation, (normal) std, '
             'same factor for all fields')
    add('--aug-mul', type=float,
        help='multiplicative data augmentation, (log-normal) std, '
             'same factor for all fields')

    # adversarial training
    add('--adv-model', type=str,
        help='enable adversary model from .models')
    add('--adv-model-spectral-norm', action='store_true',
        help='enable spectral normalization on the adversary model')
    add('--adv-criterion', type=str, default='BCEWithLogitsLoss',
        help='adversarial criterion from torch.nn')
    add('--min-reduction', action='store_true',
        help='enable minimum reduction in adversarial criterion')
    add('--cgan', action='store_true',
        help='enable conditional GAN')
    add('--adv-start', type=int, default=0,
        help='epoch to start adversarial training')
    add('--adv-label-smoothing', type=float, default=1,
        help='label of real samples for the adversary model, '
             'e.g. 0.9 for label smoothing and 1 to disable')
    add('--loss-fraction', type=float, default=0.5,
        help='final fraction of loss (vs adv-loss)')
    add('--instance-noise', type=float, default=0,
        help='noise added to the adversary inputs to stabilize training')
    add('--instance-noise-batches', type=float, default=1e4,
        help='noise annealing duration')

    # optimization
    add('--optimizer', type=str, default='Adam',
        help='optimizer from torch.optim')
    add('--lr', type=float, default=0.001,
        help='initial learning rate')
    add('--weight-decay', type=float, default=0,
        help='weight decay')
    add('--adv-lr', type=float,
        help='initial adversary learning rate')
    add('--adv-weight-decay', type=float,
        help='adversary weight decay')
    add('--reduce-lr-on-plateau', action='store_true',
        help='Enable ReduceLROnPlateau learning rate scheduler')
    add('--init-weight-std', type=float,
        help='weight initialization std')
    add('--epochs', type=int, default=128,
        help='total number of epochs to run')
    add('--seed', type=int, default=42,
        help='seed for initializing training')

    # distributed run and logging
    add('--div-data', action='store_true',
        help='enable data division among GPUs, useful with cache')
    add('--dist-backend', type=str, default='nccl',
        choices=['gloo', 'nccl'], help='distributed backend')
    add('--log-interval', type=int, default=100,
        help='interval between logging training loss')
|
|
|
|
|
|
def add_test_args(parser):
    """Register the options of the ``test`` mode.

    The shared options come from ``add_common_args``. This function adds
    the two required lists of glob patterns for the test input and target
    data.

    Args:
        parser: the ``test`` subparser on which to register the options.
    """
    add_common_args(parser)
    add = parser.add_argument

    add('--test-in-patterns', required=True, type=str_list,
        help='comma-sep. list of glob patterns for test input data')
    add('--test-tgt-patterns', required=True, type=str_list,
        help='comma-sep. list of glob patterns for test target data')
|
|
|
|
|
|
def str_list(s):
    """Split a comma-separated command-line value into a list of strings.

    Items are kept verbatim: no whitespace is stripped and empty items are
    kept. For example, ``'a,,b'`` gives ``['a', '', 'b']``.
    """
    pieces = s.split(sep=',')
    return pieces
|
|
|
|
|
|
#def int_tuple(t):
|
|
# t = t.split(',')
|
|
# t = tuple(int(i) for i in t)
|
|
# if len(t) == 1:
|
|
# t = t[0]
|
|
# elif len(t) != 6:
|
|
# raise ValueError('size must be int or 6-tuple')
|
|
# return t
|
|
|
|
|
|
def set_common_args(args):
    """Fill in runtime defaults shared by all modes, in place.

    If ``--loader-workers`` was not given, it defaults to the mini-batch
    size when that exceeds 1, and to 0 (no worker processes) otherwise.
    An explicitly given value is left untouched.

    Args:
        args: the parsed namespace. It must have ``loader_workers`` and
            ``batches`` attributes.
    """
    if args.loader_workers is not None:
        return

    args.loader_workers = args.batches if args.batches > 1 else 0
|
|
|
|
|
|
def set_train_args(args):
    """Fill in runtime defaults for the ``train`` mode, in place.

    The shared defaults are applied first via ``set_common_args``. Then:

    * ``args.val`` is True only when both ``--val-in-patterns`` and
      ``--val-tgt-patterns`` were given.
    * ``args.adv`` is True when ``--adv-model`` was given.
    * When ``args.adv`` is True, ``args.adv_lr`` and
      ``args.adv_weight_decay`` fall back to ``args.lr`` and
      ``args.weight_decay`` if they were not given.

    Warns:
        UserWarning: if exactly one of the two validation pattern options
            was given. Validation is then disabled (``args.val`` is False).
    """
    import warnings

    set_common_args(args)

    has_val_in = args.val_in_patterns is not None
    has_val_tgt = args.val_tgt_patterns is not None
    args.val = has_val_in and has_val_tgt
    if has_val_in != has_val_tgt:
        # Without this, a forgotten half of the pair would silently turn
        # validation off.
        warnings.warn('validation disabled: both --val-in-patterns and '
                      '--val-tgt-patterns are required to validate',
                      stacklevel=2)

    args.adv = args.adv_model is not None

    if args.adv:
        # the adversary inherits the generator's optimizer settings
        # unless they are given explicitly
        if args.adv_lr is None:
            args.adv_lr = args.lr
        if args.adv_weight_decay is None:
            args.adv_weight_decay = args.weight_decay
|
|
|
|
|
|
def set_test_args(args):
    """Fill in runtime defaults for the ``test`` mode, in place.

    The test mode has no mode-specific defaults, so this only applies the
    shared ones from ``set_common_args``.

    Args:
        args: the parsed namespace for the ``test`` mode.
    """
    set_common_args(args)
    return None
|