map2map/map2map/args.py

333 lines
9.8 KiB
Python

import os
import argparse
import json
import warnings
from .train import ckpt_link
def get_args(argv=None):
    """Parse command-line arguments and set runtime defaults.

    Two sub-commands are defined, ``train`` and ``test``; one of them is
    required and is stored in ``args.mode``.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior; passing a list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace`` after the mode-specific
        ``set_train_args`` / ``set_test_args`` hook has been applied.
    """
    parser = argparse.ArgumentParser(description="Transform field(s) to field(s)")
    subparsers = parser.add_subparsers(title="modes", dest="mode", required=True)

    train_parser = subparsers.add_parser(
        "train",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    test_parser = subparsers.add_parser(
        "test",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    add_train_args(train_parser)
    add_test_args(test_parser)

    # parse_args(None) falls back to sys.argv[1:], so existing callers that
    # invoke get_args() with no argument are unaffected.
    args = parser.parse_args(argv)

    if args.mode == "train":
        set_train_args(args)
    elif args.mode == "test":
        set_test_args(args)

    return args
def add_common_args(parser):
    """Register the options shared by the train and test sub-commands.

    Args:
        parser: The ``argparse`` (sub)parser to add the options to. It is
            modified in place; nothing is returned.
    """
    parser.add_argument(
        "--in-norms",
        type=str_list,
        help="comma-sep. list of input normalization functions",
    )
    parser.add_argument(
        "--tgt-norms",
        type=str_list,
        help="comma-sep. list of target normalization functions",
    )
    parser.add_argument(
        "--crop",
        type=int_tuple,
        help="size to crop the input and target data. Default is the "
        "field size. Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-start",
        type=int_tuple,
        help="starting point of the first crop. Default is the origin. "
        "Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-stop",
        type=int_tuple,
        help="stopping point of the last crop. Default is the opposite "
        "corner to the origin. Comma-sep. list of 1 or d integers",
    )
    parser.add_argument(
        "--crop-step",
        type=int_tuple,
        help="spacing between crops. Default is the crop size. "
        "Comma-sep. list of 1 or d integers",
    )
    # "--pad" is an alias; the value is stored as args.in_pad.
    parser.add_argument(
        "--in-pad",
        "--pad",
        default=0,
        type=int_tuple,
        help="size to pad the input data beyond the crop size, assuming "
        "periodic boundary condition. Comma-sep. list of 1, d, or dx2 "
        "integers, to pad equally along all axes, symmetrically on each, "
        "or by the specified size on every boundary, respectively",
    )
    parser.add_argument(
        "--tgt-pad",
        default=0,
        type=int_tuple,
        help="size to pad the target data beyond the crop size, assuming "
        "periodic boundary condition, useful for super-resolution. "
        "Comma-sep. list with the same format as --in-pad",
    )
    parser.add_argument(
        "--scale-factor",
        default=1,
        type=int,
        help="upsampling factor for super-resolution, in which case "
        "crop and pad are sizes of the input resolution",
    )

    parser.add_argument("--model", type=str, required=True, help="(generator) model")
    parser.add_argument(
        "--criterion", default="MSELoss", type=str, help="loss function"
    )
    parser.add_argument(
        "--load-state",
        default=ckpt_link,
        type=str,
        help="path to load the states of model, optimizer, rng, etc. "
        "Default is the checkpoint. "
        "Start from scratch in case of empty string or missing checkpoint",
    )
    # store_false: args.load_state_strict defaults to True and the flag
    # switches it to False.
    parser.add_argument(
        "--load-state-non-strict",
        action="store_false",
        help="allow incompatible keys when loading model states",
        dest="load_state_strict",
    )

    # The option was originally named "--batches"; that spelling is kept as
    # an alias for backward compatibility. The value is stored as
    # args.batch_size.
    parser.add_argument(
        "--batch-size",
        "--batches",
        type=int,
        required=True,
        help="mini-batch size, per GPU in training or in total in testing",
    )
    parser.add_argument(
        "--loader-workers",
        default=8,
        type=int,
        help="number of subprocesses per data loader. 0 to disable multiprocessing",
    )

    parser.add_argument(
        "--callback-at",
        # The given directory is converted to an absolute path at parse time.
        type=os.path.abspath,
        help="directory of custom code defining callbacks for models, "
        "norms, criteria, and optimizers. Disabled if not set. "
        "This is appended to the default locations, "
        "thus has the lowest priority",
    )
    parser.add_argument(
        "--misc-kwargs",
        default="{}",
        # argparse applies `type` to string defaults, so the default is
        # parsed into an empty dict.
        type=json.loads,
        help="miscellaneous keyword arguments for custom models and "
        "norms. Be careful with name collisions",
    )
def add_train_args(parser):
    """Register the options of the ``train`` sub-command on ``parser``.

    The shared options are added first through ``add_common_args``; the
    training-specific ones (data patterns, augmentation, optimizer,
    scheduler, distributed and logging settings) follow. ``parser`` is
    modified in place and nothing is returned.
    """
    add_common_args(parser)

    add = parser.add_argument  # local shorthand for the repeated call

    # --- training and validation data ---
    add(
        "--train-style-pattern",
        required=True,
        type=str,
        help="glob pattern for training data styles",
    )
    add(
        "--train-in-patterns",
        required=True,
        type=str_list,
        help="comma-sep. list of glob patterns for training input data",
    )
    add(
        "--train-tgt-patterns",
        required=True,
        type=str_list,
        help="comma-sep. list of glob patterns for training target data",
    )
    add("--val-style-pattern", type=str, help="glob pattern for validation data styles")
    add(
        "--val-in-patterns",
        type=str_list,
        help="comma-sep. list of glob patterns for validation input data",
    )
    add(
        "--val-tgt-patterns",
        type=str_list,
        help="comma-sep. list of glob patterns for validation target data",
    )

    # --- data augmentation ---
    add(
        "--augment",
        action="store_true",
        help="enable data augmentation of axis flipping and permutation",
    )
    add(
        "--aug-shift",
        type=int_tuple,
        help="data augmentation by shifting cropping by [0, aug_shift) pixels, "
        "useful for models that treat neighboring pixels differently, "
        "e.g. with strided convolutions. "
        "Comma-sep. list of 1 or d integers",
    )
    add(
        "--aug-add",
        type=float,
        help="additive data augmentation, (normal) std, same factor for all fields",
    )
    add(
        "--aug-mul",
        type=float,
        help="multiplicative data augmentation, (log-normal) std, "
        "same factor for all fields",
    )

    # --- optimization ---
    add("--optimizer", type=str, default="Adam", help="optimization algorithm")
    add("--lr", required=True, type=float, help="initial learning rate")
    add(
        "--optimizer-args",
        type=json.loads,
        default="{}",
        help="optimizer arguments in addition to the learning rate, "
        "e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'",
    )
    add(
        "--reduce-lr-on-plateau",
        action="store_true",
        help="Enable ReduceLROnPlateau learning rate scheduler",
    )
    add(
        "--scheduler-args",
        type=json.loads,
        default='{"verbose": true}',
        help="arguments for the ReduceLROnPlateau scheduler",
    )
    add("--init-weight-std", type=float, help="weight initialization std")

    # --- run control ---
    add("--epochs", type=int, default=128, help="total number of epochs to run")
    add("--seed", type=int, default=42, help="seed for initializing training")

    # --- data division among GPUs ---
    add(
        "--div-data",
        action="store_true",
        help="enable data division among GPUs for better page caching. "
        "Data division is shuffled every epoch. "
        "Only relevant if there are multiple crops in each field",
    )
    add(
        "--div-shuffle-dist",
        type=float,
        default=1,
        help="distance to further shuffle cropped samples relative to "
        "their fields, to be used with --div-data. "
        "Only relevant if there are multiple crops in each file. "
        "The order of each sample is randomly displaced by this value. "
        "Setting it to 0 turn off this randomization, and setting it to N "
        "limits the shuffling within a distance of N files. "
        "Change this to balance cache locality and stochasticity",
    )

    # --- distributed backend, logging, debugging ---
    add(
        "--dist-backend",
        type=str,
        default="nccl",
        choices=["gloo", "nccl"],
        help="distributed backend",
    )
    add(
        "--log-interval",
        type=int,
        default=100,
        help="interval (batches) between logging training loss",
    )
    add(
        "--detect-anomaly",
        action="store_true",
        help="enable anomaly detection for the autograd engine",
    )
def add_test_args(parser):
    """Register the options of the ``test`` sub-command on ``parser``.

    The shared options are added first through ``add_common_args``, then
    the required test data patterns and the optional CPU thread count.
    ``parser`` is modified in place and nothing is returned.
    """
    add_common_args(parser)

    add = parser.add_argument  # local shorthand for the repeated call

    add(
        "--test-style-pattern",
        required=True,
        type=str,
        help="glob pattern for test data styles",
    )
    add(
        "--test-in-patterns",
        required=True,
        type=str_list,
        help="comma-sep. list of glob patterns for test input data",
    )
    add(
        "--test-tgt-patterns",
        required=True,
        type=str_list,
        help="comma-sep. list of glob patterns for test target data",
    )
    add(
        "--num-threads",
        type=int,
        help="number of CPU threads when cuda is unavailable. "
        "Default is the number of CPUs on the node by slurm",
    )
def str_list(s: str) -> list:
    """Split a comma-separated string into a list of its pieces.

    Used as an argparse ``type`` converter. Pieces are not stripped, and an
    empty input yields ``[""]``.
    """
    pieces = s.split(",")
    return pieces
def int_tuple(s: str):
    """Convert a comma-separated string of integers to an int or a tuple.

    Used as an argparse ``type`` converter. A single value is returned as a
    bare ``int``; two or more values are returned as a tuple of ints.

    Raises:
        ValueError: If any piece is not a valid integer literal.
    """
    values = tuple(map(int, s.split(",")))
    return values[0] if len(values) == 1 else values
def set_common_args(args):
    """Apply runtime defaults shared by the train and test modes.

    Currently a placeholder: ``args`` is left untouched and ``None`` is
    returned.
    """
    return None
def set_train_args(args):
    """Derive training-only runtime settings from the parsed arguments.

    Sets ``args.val`` to True only when both ``--val-in-patterns`` and
    ``--val-tgt-patterns`` were given. When exactly one of the two is
    given, validation stays disabled (as before) but a warning is emitted
    so the incomplete configuration does not go unnoticed.

    Args:
        args: The parsed ``train`` namespace; modified in place.
    """
    set_common_args(args)

    has_val_in = args.val_in_patterns is not None
    has_val_tgt = args.val_tgt_patterns is not None

    if has_val_in != has_val_tgt:
        warnings.warn(
            "only one of --val-in-patterns and --val-tgt-patterns is set; "
            "validation is disabled"
        )

    args.val = has_val_in and has_val_tgt
def set_test_args(args):
    """Derive test-only runtime settings from the parsed arguments.

    Only the shared ``set_common_args`` hook is applied; there are no
    test-specific settings at present.
    """
    set_common_args(args)