diff --git a/.gitignore b/.gitignore
index 36e8838..894a44c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,6 @@
 __pycache__/
 *.py[cod]
 *$py.class
-*.swp
 
 # C extensions
 *.so
diff --git a/README.md b/README.md
index fa2875b..bcf42b5 100644
--- a/README.md
+++ b/README.md
@@ -27,11 +27,10 @@
 pip install -e .
 
 ## Usage
 
-The command is `m2m` in your `$PATH` after installation.
+The command is `m2m.py` in your `$PATH` after installation.
 Take a look at the examples in `scripts/*.slurm`.
-For all command line options look at the `map2map/*args.yaml` or do `m2m --help`, and `m2m train --help` or `m2m test --help`.
+For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 
-Another tool is the map cropper. It can take a single 3d field from a simulation and make little tiles extracted randomly from the main simulation. The training dataset is then saved in the target directory with the proper format for m2m.
 
 ### Data
diff --git a/map2map/common_args.yaml b/map2map/common_args.yaml
deleted file mode 100644
index 609f10e..0000000
--- a/map2map/common_args.yaml
+++ /dev/null
@@ -1,89 +0,0 @@
-arguments:
-  'in-norms':
-    type: str_list
-    help: 'comma-sep. list of input normalization functions'
-  'tgt-norms':
-    type: str_list
-    help: 'comma-sep. list of target normalization functions'
-  'crop':
-    type: int_tuple
-    help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
-  'crop-start':
-    type: int_tuple
-    help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
-  'crop-stop':
-    type: int_tuple
-    help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
-  'crop-step':
-    type: int_tuple
-    help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
-  'in-pad':
-    default: 0
-    type: int_tuple
-    help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
-  'tgt-pad':
-    default: 0
-    type: int_tuple
-    help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
-  'scale-factor':
-    default: 1
-    type: int
-    help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
-  'model':
-    type: str
-    required: true
-    help: '(generator) model'
-  'criterion':
-    default: 'MSELoss'
-    type: str
-    help: 'loss function'
-  'load-state':
-    default: ckpt_link
-    type: str
-    help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
-  'load-state-non-strict':
-    # action: 'store_false'
-    help: 'allow incompatible keys when loading model states'
-    # dest: 'load_state_strict'
-  'batch-size':
-    default: 0
-    type: int
-    required: true
-    help: 'mini-batch size, per GPU in training or in total in testing'
-  'loader-workers':
-    default: 8
-    type: int
-    help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
-  'callback-at':
-    type: 'abspath'
-    help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
-  'misc-kwargs':
-    default: '{}'
-    type: json
-    help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'
-
-# arguments:
-#   - 'optimizer':
-#       default: 'Adam'
-#       type: str
-#       help: 'optimizer for training'
-#   - 'learning-rate':
-#       default: 0.001
-#       type: float
-#       help: 'learning rate for training'
-#   - 'num-epochs':
-#       default: 100
-#       type: int
-#       help: 'number of training epochs'
-#   - 'save-interval':
-#       default: 10
-#       type: int
-#       help: 'interval for saving checkpoints during training'
-#   - 'log-interval':
-#       default: 10
-#       type: int
-#       help: 'interval for logging training progress'
-#   - 'device':
-#       default: 'cuda'
-#       type: str
-#       help: 'device for training (cuda or cpu)'
\ No newline at end of file
diff --git a/map2map/cropper.py b/map2map/cropper.py
deleted file mode 100644
index 52ad864..0000000
--- a/map2map/cropper.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import click
-import numpy as np
-import h5py as h5
-import pathlib
-from tqdm import tqdm
-
-
-def _extract_3d_tile_periodic(arr, tile_size, start_index):
-    periodic_indices = map(
-        lambda a: a[0] + a[1],
-        zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index),
-    )
-    periodic_indices = map(
-        lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape)
-    )
-    return arr[tuple(periodic_indices)]
-
-
-@click.command()
-@click.option("--input", required=True, type=click.Path(exists=True), help="Input file")
-@click.option("--output", required=True, type=click.Path(), help="Output directory")
-@click.option(
-    "--tiles", required=True, type=click.Tuple([int]), help="Size of the tiles"
-)
-@click.option("--fields", required=True, type=click.Tuple([str]), help="Fields to crop")
-@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop")
-def cropper(input, output, tiles, fields, num_tiles):
-    output = pathlib.PosixPath(output)
-
-    with h5.File(input, mode="r") as f:
-        for i in tqdm(range(num_tiles)):
-            a, b, c = np.random.randint(0, high=1024, size=3)
-            for field in fields:
-                tile = _extract_3d_tile_periodic(f[field], Q, (a, b, c))
-                np.save(output / "tiles" / field / "{:04d}.npy".format(i), tile)
diff --git a/map2map/main.py b/map2map/main.py
index a31643c..a8fa1dc 100644
--- a/map2map/main.py
+++ b/map2map/main.py
@@ -1,116 +1,17 @@
+from .args import get_args
 from . import train
 from . import test
-import click
-import os
-import yaml
-try:
-    from yaml import CLoader as Loader
-except ImportError:
-    from yaml import Loader
-
-import importlib.resources
-import json
-from functools import partial
-
-def _load_resource_file(resource_path):
-    # Import the package
-    pkg_files = importlib.resources.files() / resource_path
-    with pkg_files.open() as file:
-        return file.read()  # Read the file and return its content
-
-def _str_list(value):
-    return value.split(',')
-
-def _int_tuple(value):
-    t = value.split(',')
-    t = tuple(int(i) for i in t)
-    return t
-
-class VariadicType(click.ParamType):
-    """
-    A custom parameter type for Click command-line interface.
-
-    This class provides a way to define custom parameter types for Click commands.
-    It supports various types such as string, integer, float, JSON, and file paths.
-
-    Args:
-        typename (str or dict): The name of the type or a dictionary specifying the type and options.
-
-    Raises:
-        ValueError: If the typename is not recognized.
-    """
-
-    _mapper = {
-        "str_list": {"type": "string_list", "func": _str_list},
-        "int_tuple": {"type": "int_tuple", "func": _int_tuple},
-        "json": {"type": "json", "func": json.loads},
-        "int": {"type": "int"},
-        "float": {"type": "float"},
-        "str": {"type": "str"},
-        "abspath": {"type": "path", "func": os.path.abspath},
-    }
-
-    def __init__(self, typename):
-        if typename in self._mapper:
-            self._type = self._mapper[typename]
-        elif type(typename) == dict:
-            self._type = self._mapper[typename["type"]]
-            self.args = typename["opts"]
-        else:
-            raise ValueError(f"Unknown type: {typename}")
-        self._typename = typename
-        self.name = self._type["type"]
-        if "func" not in self._type:
-            self._type["func"] = eval(self._type['type'])
-
-    def convert(self, value, param, ctx):
-        try:
-            return self.type(value)
-        except Exception as e:
-            self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
-
-def _apply_options(options_file, f):
-    common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
-    common_args = common_args['arguments']
-
-    for arg in common_args:
-        argopt = common_args[arg]
-        if 'type' in argopt:
-            if type(argopt['type']) == dict and argopt['type']['type'] == 'choice':
-                argopt['type'] = click.Choice(argopt['type']['opts'])
-            else:
-                argopt['type'] = VariadicType(argopt['type'])
-            f = click.option(f'--{arg}', **argopt)(f)
-        else:
-            f = click.option(f'--{arg}', **argopt)(f)
-
-    return f
-
-m2m_options=partial(_apply_options,"common_args.yaml")
-
-@click.group()
-@click.option("--config", type=click.Path(), help="Path to config file")
-@click.pass_context
-def main(ctx, config):
-    if config is not None and os.path.exists(config):
-        with open(config, 'r') as f:
-            config = yaml.load(f.read(), Loader=Loader)
-        ctx.default_map = config
+
+
+def main():
-
-# Make a class that provides access to dict with the attribute mechanism
-class DictProxy:
-    def __init__(self, d):
-        self.__dict__ = d
+
+    args = get_args()
+
-
-@main.command()
-@m2m_options
-@partial(_apply_options, "train_args.yaml")
-def train(**kwargs):
-    train.node_worker(DictProxy(kwargs))
+    if args.mode == 'train':
+        train.node_worker(args)
+    elif args.mode == 'test':
+        test.test(args)
-
-@main.command()
-@m2m_options
-@partial(_apply_options, "test_args.yaml")
-def test(**kwargs):
-    test.test(DictProxy(kwargs))
-
+
+
+if __name__ == '__main__':
+    main()
diff --git a/map2map/test_args.yaml b/map2map/test_args.yaml
deleted file mode 100644
index 4de0f97..0000000
--- a/map2map/test_args.yaml
+++ /dev/null
@@ -1,16 +0,0 @@
-arguments:
-  'test-style-pattern':
-    type: str
-    required: true
-    help: glob pattern for test data styles
-  'test-in-patterns':
-    type: str_list
-    required: true
-    help: comma-sep. list of glob patterns for test input data
-  'test-tgt-patterns':
-    type: str_list
-    required: true
-    help: comma-sep. list of glob patterns for test target data
-  'num-threads':
-    type: int
-    help: number of CPU threads when cuda is unavailable. Default is the number of CPUs on the node by slurm
diff --git a/map2map/train_args.yaml b/map2map/train_args.yaml
deleted file mode 100644
index 4193a0c..0000000
--- a/map2map/train_args.yaml
+++ /dev/null
@@ -1,86 +0,0 @@
-arguments:
-  'train-style-pattern':
-    type: str
-    required: true
-    help: 'glob pattern for training data styles'
-  'train-in-patterns':
-    type: str_list
-    required: true
-    help: 'comma-sep. list of glob patterns for training input data'
-  'train-tgt-patterns':
-    type: str_list
-    required: true
-    help: 'comma-sep. list of glob patterns for training target data'
-  'val-style-pattern':
-    type: str
-    help: 'glob pattern for validation data styles'
-  'val-in-patterns':
-    type: str_list
-    help: 'comma-sep. list of glob patterns for validation input data'
-  'val-tgt-patterns':
-    type: str_list
-    help: 'comma-sep. list of glob patterns for validation target data'
-  'augment':
-    is_flag: true
-    help: 'enable data augmentation of axis flipping and permutation'
-  'aug-shift':
-    type: int_tuple
-    help: 'data augmentation by shifting cropping by [0, aug_shift) pixels, useful for models that treat neighboring pixels differently, e.g. with strided convolutions. Comma-sep. list of 1 or d integers'
-  'aug-add':
-    type: float
-    help: 'additive data augmentation, (normal) std, same factor for all fields'
-  'aug-mul':
-    type: float
-    help: 'multiplicative data augmentation, (log-normal) std, same factor for all fields'
-  'optimizer':
-    default: 'Adam'
-    type: str
-    help: 'optimization algorithm'
-  'lr':
-    type: float
-    required: true
-    help: 'initial learning rate'
-  'optimizer-args':
-    default: '{}'
-    type: json
-    help: "optimizer arguments in addition to the learning rate, e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'"
-  'reduce-lr-on-plateau':
-    is_flag: true
-    help: 'Enable ReduceLROnPlateau learning rate scheduler'
-  'scheduler-args':
-    default: '{"verbose": true}'
-    type: json
-    help: 'arguments for the ReduceLROnPlateau scheduler'
-  'init-weight-std':
-    type: float
-    help: 'weight initialization std'
-  'epochs':
-    default: 128
-    type: int
-    help: 'total number of epochs to run'
-  'seed':
-    default: 42
-    type: int
-    help: 'seed for initializing training'
-  'div-data':
-    is_flag: true
-    help: 'enable data division among GPUs for better page caching. Data division is shuffled every epoch. Only relevant if there are multiple crops in each field'
-  'div-shuffle-dist':
-    default: 1
-    type: float
-    help: 'distance to further shuffle cropped samples relative to their fields, to be used with --div-data. Only relevant if there are multiple crops in each file. The order of each sample is randomly displaced by this value. Setting it to 0 turn off this randomization, and setting it to N limits the shuffling within a distance of N files. Change this to balance cache locality and stochasticity'
-  'dist-backend':
-    default: 'nccl'
-    type:
-      type: "choice"
-      opts:
-        - 'gloo'
-        - 'nccl'
-    help: 'distributed backend'
-  'log-interval':
-    default: 100
-    type: int
-    help: 'interval (batches) between logging training loss'
-  'detect-anomaly':
-    is_flag: true
-    help: 'enable anomaly detection for the autograd engine'
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 5690fda..10827f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,8 +17,6 @@ numpy = "^1.26.4"
 scipy = "^1.13.0"
 matplotlib = "^3.9.0"
 tensorboard = "^2.16.2"
-click = "^8.1.7"
-pyyaml = "^6.0.1"
 
 [tool.poetry.group.dev.dependencies]
 python-semantic-release = "^9.7.3"
@@ -44,9 +42,7 @@ dependencies = [
    'numpy',
    'scipy',
    'matplotlib',
-    'tensorboard',
-    'h5py','tqdm',
-    'click','pyyaml']
+    'tensorboard']
 
 authors = [
    {name = "Yin Li", email = "eelregit@gmail.com"},
@@ -58,11 +54,10 @@
 maintainers = [
 ]
 
 [project.scripts]
-m2m = "map2map:main.main"
-mapcropper = "map2map:cropper.cropper"
+m2m = "map2map:main"
 
 [tool.poetry.scripts]
-map2map = "map2map:main.main"
+map2map = "map2map:main"
 
 [project.urls]