From 47cdaffd81099be244fba2212f8966b6ca375721 Mon Sep 17 00:00:00 2001 From: Guilhem Lavaux Date: Wed, 3 Apr 2024 16:27:39 +0200 Subject: [PATCH] feat: Use click and autogenerated arguments from YAML parameter file BREAKING CHANGE: arguments may not be propagated exactly the same way as previously --- .gitignore | 1 + README.md | 5 +- map2map/common_args.yaml | 89 +++++++++++++++++++++++++++++ map2map/cropper.py | 35 ++++++++++++ map2map/main.py | 119 +++++++++++++++++++++++++++++++++++---- map2map/test_args.yaml | 16 ++++++ map2map/train_args.yaml | 86 ++++++++++++++++++++++++++++ pyproject.toml | 11 +++- 8 files changed, 347 insertions(+), 15 deletions(-) create mode 100644 map2map/common_args.yaml create mode 100644 map2map/cropper.py create mode 100644 map2map/test_args.yaml create mode 100644 map2map/train_args.yaml diff --git a/.gitignore b/.gitignore index 894a44c..36e8838 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +*.swp # C extensions *.so diff --git a/README.md b/README.md index bcf42b5..fa2875b 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,11 @@ pip install -e . ## Usage -The command is `m2m.py` in your `$PATH` after installation. +The command is `m2m` in your `$PATH` after installation. Take a look at the examples in `scripts/*.slurm`. -For all command line options look at `map2map/args.py` or do `m2m.py -h`. +For all command line options look at the `map2map/*args.yaml` or do `m2m --help`, and `m2m train --help` or `m2m test --help`. +Another tool is the map cropper. It can take a single 3d field from a simulation and make little tiles extracted randomly from the main simulation. The training dataset is then saved in the target directory with the proper format for m2m. 
### Data diff --git a/map2map/common_args.yaml b/map2map/common_args.yaml new file mode 100644 index 0000000..609f10e --- /dev/null +++ b/map2map/common_args.yaml @@ -0,0 +1,89 @@ +arguments: + 'in-norms': + type: str_list + help: 'comma-sep. list of input normalization functions' + 'tgt-norms': + type: str_list + help: 'comma-sep. list of target normalization functions' + 'crop': + type: int_tuple + help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' + 'crop-start': + type: int_tuple + help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' + 'crop-stop': + type: int_tuple + help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' + 'crop-step': + type: int_tuple + help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' + 'in-pad': + default: 0 + type: int_tuple + help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' + 'tgt-pad': + default: 0 + type: int_tuple + help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' + 'scale-factor': + default: 1 + type: int + help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' + 'model': + type: str + required: true + help: '(generator) model' + 'criterion': + default: 'MSELoss' + type: str + help: 'loss function' + 'load-state': + default: ckpt_link + type: str + help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. 
Start from scratch in case of empty string or missing checkpoint' + 'load-state-non-strict': + # action: 'store_false' + help: 'allow incompatible keys when loading model states' + # dest: 'load_state_strict' + 'batch-size': + default: 0 + type: int + required: true + help: 'mini-batch size, per GPU in training or in total in testing' + 'loader-workers': + default: 8 + type: int + help: 'number of subprocesses per data loader. 0 to disable multiprocessing' + 'callback-at': + type: 'abspath' + help: 'directory of custom code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' + 'misc-kwargs': + default: '{}' + type: json + help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' + +# arguments: +# - 'optimizer': +# default: 'Adam' +# type: str +# help: 'optimizer for training' +# - 'learning-rate': +# default: 0.001 +# type: float +# help: 'learning rate for training' +# - 'num-epochs': +# default: 100 +# type: int +# help: 'number of training epochs' +# - 'save-interval': +# default: 10 +# type: int +# help: 'interval for saving checkpoints during training' +# - 'log-interval': +# default: 10 +# type: int +# help: 'interval for logging training progress' +# - 'device': +# default: 'cuda' +# type: str +# help: 'device for training (cuda or cpu)' \ No newline at end of file diff --git a/map2map/cropper.py b/map2map/cropper.py new file mode 100644 index 0000000..52ad864 --- /dev/null +++ b/map2map/cropper.py @@ -0,0 +1,35 @@ +import click +import numpy as np +import h5py as h5 +import pathlib +from tqdm import tqdm + + +def _extract_3d_tile_periodic(arr, tile_size, start_index): + periodic_indices = map( + lambda a: a[0] + a[1], + zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index), + ) + periodic_indices = map( + lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape) + ) + return 
arr[tuple(periodic_indices)] + + +@click.command() +@click.option("--input", required=True, type=click.Path(exists=True), help="Input file") +@click.option("--output", required=True, type=click.Path(), help="Output directory") +@click.option( + "--tiles", required=True, type=click.Tuple([int]), help="Size of the tiles" +) +@click.option("--fields", required=True, type=click.Tuple([str]), help="Fields to crop") +@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop") +def cropper(input, output, tiles, fields, num_tiles): + output = pathlib.Path(output) + + with h5.File(input, mode="r") as f: + for i in tqdm(range(num_tiles)): + a, b, c = np.random.randint(0, high=1024, size=3) # NOTE(review): start range is hard-coded to 1024; confirm against the field shape + for field in fields: + tile = _extract_3d_tile_periodic(f[field], tiles[0], (a, b, c)) + np.save(output / "tiles" / field / "{:04d}.npy".format(i), tile) diff --git a/map2map/main.py b/map2map/main.py index a8fa1dc..a31643c 100644 --- a/map2map/main.py +++ b/map2map/main.py @@ -1,17 +1,116 @@ -from .args import get_args from . import train from . import test +import click +import os +import yaml +try: + from yaml import CSafeLoader as Loader +except ImportError: + from yaml import SafeLoader as Loader + +import importlib.resources +import json +from functools import partial + +def _load_resource_file(resource_path): + # Anchor on this package explicitly: files() without an anchor requires Python >= 3.12 + pkg_files = importlib.resources.files(__package__) / resource_path + with pkg_files.open() as file: + return file.read() # Read the file and return its content + +def _str_list(value): + return value.split(',') + +def _int_tuple(value): + t = value.split(',') + t = tuple(int(i) for i in t) + return t + +class VariadicType(click.ParamType): + """ + A custom parameter type for Click command-line interface. + + This class provides a way to define custom parameter types for Click commands. + It supports various types such as string, integer, float, JSON, and file paths. 
+ + Args: + typename (str or dict): The name of the type or a dictionary specifying the type and options. + + Raises: + ValueError: If the typename is not recognized. + """ + + _mapper = { + "str_list": {"type": "string_list", "func": _str_list}, + "int_tuple": {"type": "int_tuple", "func": _int_tuple}, + "json": {"type": "json", "func": json.loads}, + "int": {"type": "int", "func": int}, + "float": {"type": "float", "func": float}, + "str": {"type": "str", "func": str}, + "abspath": {"type": "path", "func": os.path.abspath}, + } + + def __init__(self, typename): + if isinstance(typename, dict): + self._type = self._mapper[typename["type"]] + self.args = typename["opts"] + elif typename in self._mapper: + self._type = self._mapper[typename] + else: + raise ValueError(f"Unknown type: {typename}") + self._typename = typename + self.name = self._type["type"] + # Every _mapper entry names its converter explicitly, so no eval() + # lookup (and no mutation of the shared class-level table) is needed. + + def convert(self, value, param, ctx): + try: + return self._type["func"](value) if isinstance(value, str) else value + except Exception as e: + self.fail(f"Could not parse {self._typename}: {e}", param, ctx) + +def _apply_options(options_file, f): + common_args = yaml.load(_load_resource_file(options_file), Loader=Loader) + common_args = common_args['arguments'] + + for arg in common_args: + argopt = common_args[arg] + if 'type' in argopt: + if isinstance(argopt['type'], dict) and argopt['type']['type'] == 'choice': + argopt['type'] = click.Choice(argopt['type']['opts']) + else: + argopt['type'] = VariadicType(argopt['type']) + f = click.option(f'--{arg}', **argopt)(f) + else: + f = click.option(f'--{arg}', **argopt)(f) + + return f + +m2m_options=partial(_apply_options,"common_args.yaml") -def main(): +@click.group() +@click.option("--config", type=click.Path(), help="Path to config file") +@click.pass_context +def main(ctx, config): + if config is not None and os.path.exists(config): + with open(config, 'r') as f: + config = yaml.load(f.read(), Loader=Loader) + ctx.default_map = config - args = get_args() +# Make a class 
that provides access to dict with the attribute mechanism +class DictProxy: + def __init__(self, d): + self.__dict__ = d - if args.mode == 'train': - train.node_worker(args) - elif args.mode == 'test': - test.test(args) +@main.command("train") +@m2m_options +@partial(_apply_options, "train_args.yaml") +def train_command(**kwargs): + train.node_worker(DictProxy(kwargs)) - -if __name__ == '__main__': - main() +@main.command("test") +@m2m_options +@partial(_apply_options, "test_args.yaml") +def test_command(**kwargs): + test.test(DictProxy(kwargs)) diff --git a/map2map/test_args.yaml b/map2map/test_args.yaml new file mode 100644 index 0000000..4de0f97 --- /dev/null +++ b/map2map/test_args.yaml @@ -0,0 +1,16 @@ +arguments: + 'test-style-pattern': + type: str + required: true + help: glob pattern for test data styles + 'test-in-patterns': + type: str_list + required: true + help: comma-sep. list of glob patterns for test input data + 'test-tgt-patterns': + type: str_list + required: true + help: comma-sep. list of glob patterns for test target data + 'num-threads': + type: int + help: number of CPU threads when cuda is unavailable. Default is the number of CPUs on the node by slurm diff --git a/map2map/train_args.yaml b/map2map/train_args.yaml new file mode 100644 index 0000000..4193a0c --- /dev/null +++ b/map2map/train_args.yaml @@ -0,0 +1,86 @@ +arguments: + 'train-style-pattern': + type: str + required: true + help: 'glob pattern for training data styles' + 'train-in-patterns': + type: str_list + required: true + help: 'comma-sep. list of glob patterns for training input data' + 'train-tgt-patterns': + type: str_list + required: true + help: 'comma-sep. list of glob patterns for training target data' + 'val-style-pattern': + type: str + help: 'glob pattern for validation data styles' + 'val-in-patterns': + type: str_list + help: 'comma-sep. list of glob patterns for validation input data' + 'val-tgt-patterns': + type: str_list + help: 'comma-sep. 
list of glob patterns for validation target data' + 'augment': + is_flag: true + help: 'enable data augmentation of axis flipping and permutation' + 'aug-shift': + type: int_tuple + help: 'data augmentation by shifting cropping by [0, aug_shift) pixels, useful for models that treat neighboring pixels differently, e.g. with strided convolutions. Comma-sep. list of 1 or d integers' + 'aug-add': + type: float + help: 'additive data augmentation, (normal) std, same factor for all fields' + 'aug-mul': + type: float + help: 'multiplicative data augmentation, (log-normal) std, same factor for all fields' + 'optimizer': + default: 'Adam' + type: str + help: 'optimization algorithm' + 'lr': + type: float + required: true + help: 'initial learning rate' + 'optimizer-args': + default: '{}' + type: json + help: "optimizer arguments in addition to the learning rate, e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'" + 'reduce-lr-on-plateau': + is_flag: true + help: 'Enable ReduceLROnPlateau learning rate scheduler' + 'scheduler-args': + default: '{"verbose": true}' + type: json + help: 'arguments for the ReduceLROnPlateau scheduler' + 'init-weight-std': + type: float + help: 'weight initialization std' + 'epochs': + default: 128 + type: int + help: 'total number of epochs to run' + 'seed': + default: 42 + type: int + help: 'seed for initializing training' + 'div-data': + is_flag: true + help: 'enable data division among GPUs for better page caching. Data division is shuffled every epoch. Only relevant if there are multiple crops in each field' + 'div-shuffle-dist': + default: 1 + type: float + help: 'distance to further shuffle cropped samples relative to their fields, to be used with --div-data. Only relevant if there are multiple crops in each file. The order of each sample is randomly displaced by this value. Setting it to 0 turns off this randomization, and setting it to N limits the shuffling within a distance of N files. 
Change this to balance cache locality and stochasticity' + 'dist-backend': + default: 'nccl' + type: + type: "choice" + opts: + - 'gloo' + - 'nccl' + help: 'distributed backend' + 'log-interval': + default: 100 + type: int + help: 'interval (batches) between logging training loss' + 'detect-anomaly': + is_flag: true + help: 'enable anomaly detection for the autograd engine' \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 10827f1..5690fda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ numpy = "^1.26.4" scipy = "^1.13.0" matplotlib = "^3.9.0" tensorboard = "^2.16.2" +click = "^8.1.7" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] python-semantic-release = "^9.7.3" @@ -42,7 +44,9 @@ dependencies = [ 'numpy', 'scipy', 'matplotlib', - 'tensorboard'] + 'tensorboard', + 'h5py','tqdm', + 'click','pyyaml'] authors = [ {name = "Yin Li", email = "eelregit@gmail.com"}, @@ -54,10 +58,11 @@ maintainers = [ ] [project.scripts] -m2m = "map2map:main" +m2m = "map2map:main.main" +mapcropper = "map2map:cropper.cropper" [tool.poetry.scripts] -map2map = "map2map:main" +map2map = "map2map:main.main" [project.urls]