From 4579651bba8f871b6815556ab6686f9c31eb72df Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Wed, 3 Apr 2024 17:00:41 +0200 Subject: [PATCH] Add much more flexible argument handling --- map2map/common_args.yaml | 140 +++++++++++++++++++-------------------- map2map/main.py | 54 ++++++++++++--- pyproject.toml | 5 +- 3 files changed, 118 insertions(+), 81 deletions(-) diff --git a/map2map/common_args.yaml b/map2map/common_args.yaml index 0d0e18b..4fdaa68 100644 --- a/map2map/common_args.yaml +++ b/map2map/common_args.yaml @@ -1,88 +1,88 @@ arguments: - 'in-norms': - type: str_list - help: 'comma-sep. list of input normalization functions' + type: str_list + help: 'comma-sep. list of input normalization functions' - 'tgt-norms': - type: str_list - help: 'comma-sep. list of target normalization functions' + type: str_list + help: 'comma-sep. list of target normalization functions' - 'crop': - type: int_tuple - help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' + type: int_tuple + help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' - 'crop-start': - type: int_tuple - help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' + type: int_tuple + help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' - 'crop-stop': - type: int_tuple - help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' + type: int_tuple + help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' - 'crop-step': - type: int_tuple - help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' + type: int_tuple + help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' - 'in-pad': - 'pad': 0 - type: int_tuple - help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' + pad: 0 + type: int_tuple + help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' - 'tgt-pad': - default: 0 - type: int_tuple - help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' + default: 0 + type: int_tuple + help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' - 'scale-factor': - default: 1 - type: int - help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' + default: 1 + type: int + help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' - 'model': - type: str - required: true - help: '(generator) model' + type: str + required: true + help: '(generator) model' - 'criterion': - default: 'MSELoss' - type: str - help: 'loss function' + default: 'MSELoss' + type: str + help: 'loss function' - 'load-state': - default: ckpt_link - type: str - help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' + default: ckpt_link + type: str + help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' - 'load-state-non-strict': - action: 'store_false' - help: 'allow incompatible keys when loading model states' - dest: 'load_state_strict' + action: 'store_false' + help: 'allow incompatible keys when loading model states' + dest: 'load_state_strict' - 'batch-size': - 'batches': 0 - type: int - required: true - help: 'mini-batch size, per GPU in training or in total in testing' + 'batches': 0 + type: int + required: true + help: 'mini-batch size, per GPU in training or in total in testing' - 'loader-workers': - default: 8 - type: int - help: 'number of subprocesses per data loader. 0 to disable multiprocessing' + default: 8 + type: int + help: 'number of subprocesses per data loader. 0 to disable multiprocessing' - 'callback-at': - type: 'lambda s: os.path.abspath(s)' - help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' + type: 'lambda s: os.path.abspath(s)' + help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' - 'misc-kwargs': - default: '{}' - type: json.loads - help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' - arguments: - - 'optimizer': - default: 'Adam' - type: str - help: 'optimizer for training' - - 'learning-rate': - default: 0.001 - type: float - help: 'learning rate for training' - - 'num-epochs': - default: 100 - type: int - help: 'number of training epochs' - - 'save-interval': - default: 10 - type: int - help: 'interval for saving checkpoints during training' - - 'log-interval': - default: 10 - type: int - help: 'interval for logging training progress' - - 'device': - default: 'cuda' - type: str - help: 'device for training (cuda or cpu)' \ No newline at end of file + default: '{}' + type: json.loads + help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' + arguments: + - 'optimizer': + default: 'Adam' + type: str + help: 'optimizer for training' + - 'learning-rate': + default: 0.001 + type: float + help: 'learning rate for training' + - 'num-epochs': + default: 100 + type: int + help: 'number of training epochs' + - 'save-interval': + default: 10 + type: int + help: 'interval for saving checkpoints during training' + - 'log-interval': + default: 10 + type: int + help: 'interval for logging training progress' + - 'device': + default: 'cuda' + type: str + help: 'device for training (cuda or cpu)' \ No newline at end of file diff --git a/map2map/main.py b/map2map/main.py index a8fa1dc..319e85f 100644 --- a/map2map/main.py +++ b/map2map/main.py @@ -1,17 +1,53 @@ from .args import get_args from . import train from . import test +import click +import os +import yaml +try: + from yaml import CLoader as Loader +except ImportError: + from yaml import Loader +import importlib.resources -def main(): +def _load_resource_file(resource_path): + package = importlib.import_module('map2map') # Import the package + with importlib.resources.path('map2map', resource_path) as path: + return path.read_text() # Read the file and return its content +def str_list(s): + return s.split(',') + +def m2m_options(f): + common_args = _load_resource_file('common_args.yaml') + + for arg in common_args['arguments']: + argopt = common_args[arg] + if 'type' in argopt: + argopt['type'] = eval(argopt['type']) + f = click.option(f'--{arg}', **argopt)(f) + else: + f = click.option(f'--{arg}', **argopt)(f) + + return f + +@click.group() +@click.option("--config", type=click.Path()) +@click.pass_context +def main(ctx, config): + if os.path.exists(config): + with open(config, 'r') as f: + config = yaml.load(f.read(), Loader=Loader) + ctx.default_map = config + +@main.command() +@m2m_options +def train(**kwargs): args = get_args() + train.node_worker(args) - if args.mode == 'train': - train.node_worker(args) - elif args.mode == 'test': - test.test(args) - - -if __name__ == '__main__': - main() +@main.command() +@m2m_options +def test(): + test.test(args) diff --git a/pyproject.toml b/pyproject.toml index b90f7ef..098c150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,8 @@ dependencies = [ 'numpy', 'scipy', 'matplotlib', - 'tensorboard'] + 'tensorboard', + 'click'] authors = [ {name = "Yin Li", email = "eelregit@gmail.com"}, @@ -29,7 +30,7 @@ maintainers = [ ] [project.scripts] -m2m = "map2map:main" +m2m = "map2map:main.main" [project.urls] #Homepage = "https://example.com"