From 8bf3dd841e9ae3c1bb8f8a96b6fc0737186915cc Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Wed, 3 Apr 2024 18:53:44 +0200 Subject: [PATCH] Add more arguments --- map2map/common_args.yaml | 175 ++++++++++++++++++++------------------- map2map/main.py | 90 ++++++++++++++++---- map2map/test_args.yaml | 16 ++++ map2map/train_args.yaml | 86 +++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 263 insertions(+), 106 deletions(-) create mode 100644 map2map/test_args.yaml create mode 100644 map2map/train_args.yaml diff --git a/map2map/common_args.yaml b/map2map/common_args.yaml index 4fdaa68..609f10e 100644 --- a/map2map/common_args.yaml +++ b/map2map/common_args.yaml @@ -1,88 +1,89 @@ arguments: - - 'in-norms': - type: str_list - help: 'comma-sep. list of input normalization functions' - - 'tgt-norms': - type: str_list - help: 'comma-sep. list of target normalization functions' - - 'crop': - type: int_tuple - help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' - - 'crop-start': - type: int_tuple - help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' - - 'crop-stop': - type: int_tuple - help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' - - 'crop-step': - type: int_tuple - help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' - - 'in-pad': - pad: 0 - type: int_tuple - help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' - - 'tgt-pad': - default: 0 - type: int_tuple - help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' - - 'scale-factor': - default: 1 - type: int - help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' - - 'model': - type: str - required: true - help: '(generator) model' - - 'criterion': - default: 'MSELoss' - type: str - help: 'loss function' - - 'load-state': - default: ckpt_link - type: str - help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' - - 'load-state-non-strict': - action: 'store_false' - help: 'allow incompatible keys when loading model states' - dest: 'load_state_strict' - - 'batch-size': - 'batches': 0 - type: int - required: true - help: 'mini-batch size, per GPU in training or in total in testing' - - 'loader-workers': - default: 8 - type: int - help: 'number of subprocesses per data loader. 0 to disable multiprocessing' - - 'callback-at': - type: 'lambda s: os.path.abspath(s)' - help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' - - 'misc-kwargs': - default: '{}' - type: json.loads - help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' - arguments: - - 'optimizer': - default: 'Adam' - type: str - help: 'optimizer for training' - - 'learning-rate': - default: 0.001 - type: float - help: 'learning rate for training' - - 'num-epochs': - default: 100 - type: int - help: 'number of training epochs' - - 'save-interval': - default: 10 - type: int - help: 'interval for saving checkpoints during training' - - 'log-interval': - default: 10 - type: int - help: 'interval for logging training progress' - - 'device': - default: 'cuda' - type: str - help: 'device for training (cuda or cpu)' \ No newline at end of file + 'in-norms': + type: str_list + help: 'comma-sep. list of input normalization functions' + 'tgt-norms': + type: str_list + help: 'comma-sep. list of target normalization functions' + 'crop': + type: int_tuple + help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' + 'crop-start': + type: int_tuple + help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' + 'crop-stop': + type: int_tuple + help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' + 'crop-step': + type: int_tuple + help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' + 'in-pad': + default: 0 + type: int_tuple + help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' + 'tgt-pad': + default: 0 + type: int_tuple + help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' + 'scale-factor': + default: 1 + type: int + help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' + 'model': + type: str + required: true + help: '(generator) model' + 'criterion': + default: 'MSELoss' + type: str + help: 'loss function' + 'load-state': + default: ckpt_link + type: str + help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' + 'load-state-non-strict': + # action: 'store_false' + help: 'allow incompatible keys when loading model states' + # dest: 'load_state_strict' + 'batch-size': + default: 0 + type: int + required: true + help: 'mini-batch size, per GPU in training or in total in testing' + 'loader-workers': + default: 8 + type: int + help: 'number of subprocesses per data loader. 0 to disable multiprocessing' + 'callback-at': + type: 'abspath' + help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' + 'misc-kwargs': + default: '{}' + type: json + help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' + +# arguments: +# - 'optimizer': +# default: 'Adam' +# type: str +# help: 'optimizer for training' +# - 'learning-rate': +# default: 0.001 +# type: float +# help: 'learning rate for training' +# - 'num-epochs': +# default: 100 +# type: int +# help: 'number of training epochs' +# - 'save-interval': +# default: 10 +# type: int +# help: 'interval for saving checkpoints during training' +# - 'log-interval': +# default: 10 +# type: int +# help: 'interval for logging training progress' +# - 'device': +# default: 'cuda' +# type: str +# help: 'device for training (cuda or cpu)' \ No newline at end of file diff --git a/map2map/main.py b/map2map/main.py index 319e85f..03a10cb 100644 --- a/map2map/main.py +++ b/map2map/main.py @@ -10,44 +10,98 @@ except ImportError: from yaml import Loader import importlib.resources +import json +from functools import partial def _load_resource_file(resource_path): - package = importlib.import_module('map2map') # Import the package - with importlib.resources.path('map2map', resource_path) as path: - return path.read_text() # Read the file and return its content + # Import the package + pkg_files = importlib.resources.files() + with pkg_files.open(resource_path) as file: + return file.read_text() # Read the file and return its content -def str_list(s): - return s.split(',') +def _str_list(value): + return value.split(',') -def m2m_options(f): - common_args = _load_resource_file('common_args.yaml') +def _int_tuple(value): + t = value.split(',') + t = tuple(int(i) for i in t) + return t - for arg in common_args['arguments']: +class VariadicType(click.ParamType): + + _mapper = { + "str_list": {"type": "string_list", "func": _str_list}, + "int_tuple": {"type": "int_tuple", "func": _int_tuple}, + "json": {"type": "json", "func": json.loads}, + "int": {"type": "int"}, + "float": {"type": "float"}, + "str": {"type": "str"}, + "abspath": {"type": "path", "func": os.path.abspath}, + } + + def __init__(self, typename): + if typename in self._mapper: + self._type = self._mapper[typename] + elif type(typename) == dict: + self._type = self._mapper[typename["type"]] + self.args = typename["opts"] + else: + raise ValueError(f"Unknown type: {typename}") + self._typename = typename + self.name = self._type["type"] + if "func" not in self._type: + self._type["func"] = eval(self._type['type']) + + def convert(self, value, param, ctx): + try: + return self.type(value) + except Exception as e: + self.fail(f"Could not parse {self._typename}: {e}", param, ctx) + + +def _apply_options(options_file, f): + common_args = yaml.load(_load_resource_file(options_file), Loader=Loader) + common_args = common_args['arguments'] + + for arg in common_args: argopt = common_args[arg] if 'type' in argopt: - argopt['type'] = eval(argopt['type']) + if type(argopt['type']) == dict and argopt['type']['type'] == 'choice': + argopt['type'] = click.Choice(argopt['type']['opts']) + else: + argopt['type'] = VariadicType(argopt['type']) f = click.option(f'--{arg}', **argopt)(f) else: f = click.option(f'--{arg}', **argopt)(f) return f +def m2m_options(f): + return _apply_options("common_args.yaml", f) + + @click.group() -@click.option("--config", type=click.Path()) +@click.option("--config", type=click.Path(), help="Path to config file") @click.pass_context def main(ctx, config): - if os.path.exists(config): + if config is not None and os.path.exists(config): with open(config, 'r') as f: config = yaml.load(f.read(), Loader=Loader) ctx.default_map = config -@main.command() -@m2m_options -def train(**kwargs): - args = get_args() - train.node_worker(args) +# Make a class that provides access to dict with the attribute mechanism +class DictProxy: + def __init__(self, d): + self.__dict__ = d @main.command() @m2m_options -def test(): - test.test(args) +@partial(_apply_options, "train_args.yaml") +def train(**kwargs): + train.node_worker(DictProxy(kwargs)) + +@main.command() +@m2m_options +@partial(_apply_options, "test_args.yaml") +def test(**kwargs): + test.test(DictProxy(kwargs)) diff --git a/map2map/test_args.yaml b/map2map/test_args.yaml new file mode 100644 index 0000000..4de0f97 --- /dev/null +++ b/map2map/test_args.yaml @@ -0,0 +1,16 @@ +arguments: + 'test-style-pattern': + type: str + required: true + help: glob pattern for test data styles + 'test-in-patterns': + type: str_list + required: true + help: comma-sep. list of glob patterns for test input data + 'test-tgt-patterns': + type: str_list + required: true + help: comma-sep. list of glob patterns for test target data + 'num-threads': + type: int + help: number of CPU threads when cuda is unavailable. Default is the number of CPUs on the node by slurm diff --git a/map2map/train_args.yaml b/map2map/train_args.yaml new file mode 100644 index 0000000..4193a0c --- /dev/null +++ b/map2map/train_args.yaml @@ -0,0 +1,86 @@ +arguments: + 'train-style-pattern': + type: str + required: true + help: 'glob pattern for training data styles' + 'train-in-patterns': + type: str_list + required: true + help: 'comma-sep. list of glob patterns for training input data' + 'train-tgt-patterns': + type: str_list + required: true + help: 'comma-sep. list of glob patterns for training target data' + 'val-style-pattern': + type: str + help: 'glob pattern for validation data styles' + 'val-in-patterns': + type: str_list + help: 'comma-sep. list of glob patterns for validation input data' + 'val-tgt-patterns': + type: str_list + help: 'comma-sep. list of glob patterns for validation target data' + 'augment': + is_flag: true + help: 'enable data augmentation of axis flipping and permutation' + 'aug-shift': + type: int_tuple + help: 'data augmentation by shifting cropping by [0, aug_shift) pixels, useful for models that treat neighboring pixels differently, e.g. with strided convolutions. Comma-sep. list of 1 or d integers' + 'aug-add': + type: float + help: 'additive data augmentation, (normal) std, same factor for all fields' + 'aug-mul': + type: float + help: 'multiplicative data augmentation, (log-normal) std, same factor for all fields' + 'optimizer': + default: 'Adam' + type: str + help: 'optimization algorithm' + 'lr': + type: float + required: true + help: 'initial learning rate' + 'optimizer-args': + default: '{}' + type: json + help: "optimizer arguments in addition to the learning rate, e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'" + 'reduce-lr-on-plateau': + is_flag: true + help: 'Enable ReduceLROnPlateau learning rate scheduler' + 'scheduler-args': + default: '{"verbose": true}' + type: json + help: 'arguments for the ReduceLROnPlateau scheduler' + 'init-weight-std': + type: float + help: 'weight initialization std' + 'epochs': + default: 128 + type: int + help: 'total number of epochs to run' + 'seed': + default: 42 + type: int + help: 'seed for initializing training' + 'div-data': + is_flag: true + help: 'enable data division among GPUs for better page caching. Data division is shuffled every epoch. Only relevant if there are multiple crops in each field' + 'div-shuffle-dist': + default: 1 + type: float + help: 'distance to further shuffle cropped samples relative to their fields, to be used with --div-data. Only relevant if there are multiple crops in each file. The order of each sample is randomly displaced by this value. Setting it to 0 turn off this randomization, and setting it to N limits the shuffling within a distance of N files. Change this to balance cache locality and stochasticity' + 'dist-backend': + default: 'nccl' + type: + type: "choice" + opts: + - 'gloo' + - 'nccl' + help: 'distributed backend' + 'log-interval': + default: 100 + type: int + help: 'interval (batches) between logging training loss' + 'detect-anomaly': + is_flag: true + help: 'enable anomaly detection for the autograd engine' \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7214b5f..5e79339 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ 'scipy', 'matplotlib', 'tensorboard', - 'click'] + 'click','pyyaml'] authors = [ {name = "Yin Li", email = "eelregit@gmail.com"},