Add more arguments

This commit is contained in:
Guilhem Lavaux 2024-04-03 18:53:44 +02:00
parent 813a4e95ee
commit c2a161d5e1
5 changed files with 263 additions and 106 deletions

View File

@ -1,88 +1,89 @@
arguments: arguments:
- 'in-norms': 'in-norms':
type: str_list type: str_list
help: 'comma-sep. list of input normalization functions' help: 'comma-sep. list of input normalization functions'
- 'tgt-norms': 'tgt-norms':
type: str_list type: str_list
help: 'comma-sep. list of target normalization functions' help: 'comma-sep. list of target normalization functions'
- 'crop': 'crop':
type: int_tuple type: int_tuple
help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
- 'crop-start': 'crop-start':
type: int_tuple type: int_tuple
help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
- 'crop-stop': 'crop-stop':
type: int_tuple type: int_tuple
help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
- 'crop-step': 'crop-step':
type: int_tuple type: int_tuple
help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
- 'in-pad': 'in-pad':
pad: 0 default: 0
type: int_tuple type: int_tuple
help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
- 'tgt-pad': 'tgt-pad':
default: 0 default: 0
type: int_tuple type: int_tuple
help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
- 'scale-factor': 'scale-factor':
default: 1 default: 1
type: int type: int
help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
- 'model': 'model':
type: str type: str
required: true required: true
help: '(generator) model' help: '(generator) model'
- 'criterion': 'criterion':
default: 'MSELoss' default: 'MSELoss'
type: str type: str
help: 'loss function' help: 'loss function'
- 'load-state': 'load-state':
default: ckpt_link default: ckpt_link
type: str type: str
help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
- 'load-state-non-strict': 'load-state-non-strict':
action: 'store_false' # action: 'store_false'
help: 'allow incompatible keys when loading model states' help: 'allow incompatible keys when loading model states'
dest: 'load_state_strict' # dest: 'load_state_strict'
- 'batch-size': 'batch-size':
'batches': 0 default: 0
type: int type: int
required: true required: true
help: 'mini-batch size, per GPU in training or in total in testing' help: 'mini-batch size, per GPU in training or in total in testing'
- 'loader-workers': 'loader-workers':
default: 8 default: 8
type: int type: int
help: 'number of subprocesses per data loader. 0 to disable multiprocessing' help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
- 'callback-at': 'callback-at':
type: 'lambda s: os.path.abspath(s)' type: 'abspath'
help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
- 'misc-kwargs': 'misc-kwargs':
default: '{}' default: '{}'
type: json.loads type: json
help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'
arguments:
- 'optimizer': # arguments:
default: 'Adam' # - 'optimizer':
type: str # default: 'Adam'
help: 'optimizer for training' # type: str
- 'learning-rate': # help: 'optimizer for training'
default: 0.001 # - 'learning-rate':
type: float # default: 0.001
help: 'learning rate for training' # type: float
- 'num-epochs': # help: 'learning rate for training'
default: 100 # - 'num-epochs':
type: int # default: 100
help: 'number of training epochs' # type: int
- 'save-interval': # help: 'number of training epochs'
default: 10 # - 'save-interval':
type: int # default: 10
help: 'interval for saving checkpoints during training' # type: int
- 'log-interval': # help: 'interval for saving checkpoints during training'
default: 10 # - 'log-interval':
type: int # default: 10
help: 'interval for logging training progress' # type: int
- 'device': # help: 'interval for logging training progress'
default: 'cuda' # - 'device':
type: str # default: 'cuda'
help: 'device for training (cuda or cpu)' # type: str
# help: 'device for training (cuda or cpu)'

View File

@ -10,44 +10,98 @@ except ImportError:
from yaml import Loader from yaml import Loader
import importlib.resources import importlib.resources
import json
from functools import partial
def _load_resource_file(resource_path): def _load_resource_file(resource_path):
package = importlib.import_module('map2map') # Import the package # Import the package
with importlib.resources.path('map2map', resource_path) as path: pkg_files = importlib.resources.files()
return path.read_text() # Read the file and return its content with pkg_files.open(resource_path) as file:
return file.read_text() # Read the file and return its content
def str_list(s): def _str_list(value):
return s.split(',') return value.split(',')
def m2m_options(f): def _int_tuple(value):
common_args = _load_resource_file('common_args.yaml') t = value.split(',')
t = tuple(int(i) for i in t)
return t
for arg in common_args['arguments']: class VariadicType(click.ParamType):
_mapper = {
"str_list": {"type": "string_list", "func": _str_list},
"int_tuple": {"type": "int_tuple", "func": _int_tuple},
"json": {"type": "json", "func": json.loads},
"int": {"type": "int"},
"float": {"type": "float"},
"str": {"type": "str"},
"abspath": {"type": "path", "func": os.path.abspath},
}
def __init__(self, typename):
if typename in self._mapper:
self._type = self._mapper[typename]
elif type(typename) == dict:
self._type = self._mapper[typename["type"]]
self.args = typename["opts"]
else:
raise ValueError(f"Unknown type: {typename}")
self._typename = typename
self.name = self._type["type"]
if "func" not in self._type:
self._type["func"] = eval(self._type['type'])
def convert(self, value, param, ctx):
try:
return self.type(value)
except Exception as e:
self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
def _apply_options(options_file, f):
common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
common_args = common_args['arguments']
for arg in common_args:
argopt = common_args[arg] argopt = common_args[arg]
if 'type' in argopt: if 'type' in argopt:
argopt['type'] = eval(argopt['type']) if type(argopt['type']) == dict and argopt['type']['type'] == 'choice':
argopt['type'] = click.Choice(argopt['type']['opts'])
else:
argopt['type'] = VariadicType(argopt['type'])
f = click.option(f'--{arg}', **argopt)(f) f = click.option(f'--{arg}', **argopt)(f)
else: else:
f = click.option(f'--{arg}', **argopt)(f) f = click.option(f'--{arg}', **argopt)(f)
return f return f
def m2m_options(f):
return _apply_options("common_args.yaml", f)
@click.group() @click.group()
@click.option("--config", type=click.Path()) @click.option("--config", type=click.Path(), help="Path to config file")
@click.pass_context @click.pass_context
def main(ctx, config): def main(ctx, config):
if os.path.exists(config): if config is not None and os.path.exists(config):
with open(config, 'r') as f: with open(config, 'r') as f:
config = yaml.load(f.read(), Loader=Loader) config = yaml.load(f.read(), Loader=Loader)
ctx.default_map = config ctx.default_map = config
@main.command() # Make a class that provides access to dict with the attribute mechanism
@m2m_options class DictProxy:
def train(**kwargs): def __init__(self, d):
args = get_args() self.__dict__ = d
train.node_worker(args)
@main.command() @main.command()
@m2m_options @m2m_options
def test(): @partial(_apply_options, "train_args.yaml")
test.test(args) def train(**kwargs):
train.node_worker(DictProxy(kwargs))
@main.command()
@m2m_options
@partial(_apply_options, "test_args.yaml")
def test(**kwargs):
test.test(DictProxy(kwargs))

16
map2map/test_args.yaml Normal file
View File

@ -0,0 +1,16 @@
arguments:
'test-style-pattern':
type: str
required: true
help: glob pattern for test data styles
'test-in-patterns':
type: str_list
required: true
help: comma-sep. list of glob patterns for test input data
'test-tgt-patterns':
type: str_list
required: true
help: comma-sep. list of glob patterns for test target data
'num-threads':
type: int
help: number of CPU threads when cuda is unavailable. Default is the number of CPUs on the node by slurm

86
map2map/train_args.yaml Normal file
View File

@ -0,0 +1,86 @@
arguments:
'train-style-pattern':
type: str
required: true
help: 'glob pattern for training data styles'
'train-in-patterns':
type: str_list
required: true
help: 'comma-sep. list of glob patterns for training input data'
'train-tgt-patterns':
type: str_list
required: true
help: 'comma-sep. list of glob patterns for training target data'
'val-style-pattern':
type: str
help: 'glob pattern for validation data styles'
'val-in-patterns':
type: str_list
help: 'comma-sep. list of glob patterns for validation input data'
'val-tgt-patterns':
type: str_list
help: 'comma-sep. list of glob patterns for validation target data'
'augment':
is_flag: true
help: 'enable data augmentation of axis flipping and permutation'
'aug-shift':
type: int_tuple
help: 'data augmentation by shifting cropping by [0, aug_shift) pixels, useful for models that treat neighboring pixels differently, e.g. with strided convolutions. Comma-sep. list of 1 or d integers'
'aug-add':
type: float
help: 'additive data augmentation, (normal) std, same factor for all fields'
'aug-mul':
type: float
help: 'multiplicative data augmentation, (log-normal) std, same factor for all fields'
'optimizer':
default: 'Adam'
type: str
help: 'optimization algorithm'
'lr':
type: float
required: true
help: 'initial learning rate'
'optimizer-args':
default: '{}'
type: json
help: "optimizer arguments in addition to the learning rate, e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'"
'reduce-lr-on-plateau':
is_flag: true
help: 'Enable ReduceLROnPlateau learning rate scheduler'
'scheduler-args':
default: '{"verbose": true}'
type: json
help: 'arguments for the ReduceLROnPlateau scheduler'
'init-weight-std':
type: float
help: 'weight initialization std'
'epochs':
default: 128
type: int
help: 'total number of epochs to run'
'seed':
default: 42
type: int
help: 'seed for initializing training'
'div-data':
is_flag: true
help: 'enable data division among GPUs for better page caching. Data division is shuffled every epoch. Only relevant if there are multiple crops in each field'
'div-shuffle-dist':
default: 1
type: float
help: 'distance to further shuffle cropped samples relative to their fields, to be used with --div-data. Only relevant if there are multiple crops in each file. The order of each sample is randomly displaced by this value. Setting it to 0 turn off this randomization, and setting it to N limits the shuffling within a distance of N files. Change this to balance cache locality and stochasticity'
'dist-backend':
default: 'nccl'
type:
type: "choice"
opts:
- 'gloo'
- 'nccl'
help: 'distributed backend'
'log-interval':
default: 100
type: int
help: 'interval (batches) between logging training loss'
'detect-anomaly':
is_flag: true
help: 'enable anomaly detection for the autograd engine'

View File

@ -38,7 +38,7 @@ dependencies = [
'scipy', 'scipy',
'matplotlib', 'matplotlib',
'tensorboard', 'tensorboard',
'click'] 'click','pyyaml']
authors = [ authors = [
{name = "Yin Li", email = "eelregit@gmail.com"}, {name = "Yin Li", email = "eelregit@gmail.com"},