feat: Use click and autogenerated arguments from YAML parameter file
BREAKING CHANGE: arguments may not be propagated in exactly the same way as before
parent 661b17eea1
commit 47cdaffd81
.gitignore (vendored)

@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
+*.swp

 # C extensions
 *.so
README.md

@@ -27,10 +27,11 @@ pip install -e .

 ## Usage

-The command is `m2m.py` in your `$PATH` after installation.
+The command is `m2m` in your `$PATH` after installation.
 Take a look at the examples in `scripts/*.slurm`.
-For all command line options look at `map2map/args.py` or do `m2m.py -h`.
+For all command line options, look at `map2map/*args.yaml`, or run `m2m --help`, `m2m train --help`, or `m2m test --help`.

+Another tool is the map cropper (`mapcropper`). It takes a single 3D field from a simulation and extracts small tiles at random positions from it. The training dataset is then saved in the target directory in the proper format for m2m.

 ### Data

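With the options now generated from the YAML files, the quickest way to see what `m2m train` accepts is its `--help` output. The short sketch below inspects it programmatically; it assumes the package is installed and that the `*args.yaml` files ship as package data.

```python
# Sketch: list the autogenerated training options (assumes map2map is
# installed with its *args.yaml files packaged as data).
from click.testing import CliRunner
from map2map.main import main

result = CliRunner().invoke(main, ["train", "--help"])
print(result.output)  # shows --model, --lr, --batch-size, ... as declared in the YAML files
```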
map2map/common_args.yaml (new file)

arguments:
  'in-norms':
    type: str_list
    help: 'comma-sep. list of input normalization functions'
  'tgt-norms':
    type: str_list
    help: 'comma-sep. list of target normalization functions'
  'crop':
    type: int_tuple
    help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
  'crop-start':
    type: int_tuple
    help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
  'crop-stop':
    type: int_tuple
    help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
  'crop-step':
    type: int_tuple
    help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
  'in-pad':
    default: 0
    type: int_tuple
    help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
  'tgt-pad':
    default: 0
    type: int_tuple
    help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
  'scale-factor':
    default: 1
    type: int
    help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
  'model':
    type: str
    required: true
    help: '(generator) model'
  'criterion':
    default: 'MSELoss'
    type: str
    help: 'loss function'
  'load-state':
    default: ckpt_link
    type: str
    help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
  'load-state-non-strict':
    # action: 'store_false'
    help: 'allow incompatible keys when loading model states'
    # dest: 'load_state_strict'
  'batch-size':
    default: 0
    type: int
    required: true
    help: 'mini-batch size, per GPU in training or in total in testing'
  'loader-workers':
    default: 8
    type: int
    help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
  'callback-at':
    type: 'abspath'
    help: 'directory of custom code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
  'misc-kwargs':
    default: '{}'
    type: json
    help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'

# arguments:
#   - 'optimizer':
#       default: 'Adam'
#       type: str
#       help: 'optimizer for training'
#   - 'learning-rate':
#       default: 0.001
#       type: float
#       help: 'learning rate for training'
#   - 'num-epochs':
#       default: 100
#       type: int
#       help: 'number of training epochs'
#   - 'save-interval':
#       default: 10
#       type: int
#       help: 'interval for saving checkpoints during training'
#   - 'log-interval':
#       default: 10
#       type: int
#       help: 'interval for logging training progress'
#   - 'device':
#       default: 'cuda'
#       type: str
#       help: 'device for training (cuda or cpu)'
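For orientation, an entry such as `crop` above ends up as an ordinary click option with a custom parameter type. A minimal hand-written equivalent is sketched below; in the actual code the option is generated by `_apply_options()` in `map2map/main.py` (shown later in this diff), and the `IntTuple` class here is a hypothetical stand-in for its `VariadicType`.

```python
# Minimal sketch of what the 'crop' entry above amounts to once generated.
import click


def _int_tuple(value):
    # "16,16,16" -> (16, 16, 16)
    return tuple(int(i) for i in value.split(","))


class IntTuple(click.ParamType):
    name = "int_tuple"

    def convert(self, value, param, ctx):
        try:
            return _int_tuple(value)
        except Exception as e:
            self.fail(f"could not parse int_tuple: {e}", param, ctx)


@click.command()
@click.option("--crop", type=IntTuple(),
              help="size to crop the input and target data. "
                   "Comma-sep. list of 1 or d integers")
def cmd(crop):
    click.echo(repr(crop))  # e.g. `cmd --crop 16,16,16` prints (16, 16, 16)


if __name__ == "__main__":
    cmd()
```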
map2map/cropper.py (new file)

import click
import numpy as np
import h5py as h5
import pathlib
from tqdm import tqdm


def _extract_3d_tile_periodic(arr, tile_size, start_index):
    # Index grids for a cube of side tile_size anchored at start_index,
    # wrapped around the array boundaries (periodic boundary condition).
    periodic_indices = map(
        lambda a: a[0] + a[1],
        zip(np.ogrid[:tile_size, :tile_size, :tile_size], start_index),
    )
    periodic_indices = map(
        lambda a: np.mod(a[0], a[1]), zip(periodic_indices, arr.shape)
    )
    return arr[tuple(periodic_indices)]


@click.command()
@click.option("--input", required=True, type=click.Path(exists=True), help="Input file")
@click.option("--output", required=True, type=click.Path(), help="Output directory")
@click.option(
    "--tiles", required=True, type=click.Tuple([int]), help="Size of the tiles"
)
@click.option("--fields", required=True, type=click.Tuple([str]), help="Fields to crop")
@click.option("--num_tiles", required=True, type=int, help="Number of tiles to crop")
def cropper(input, output, tiles, fields, num_tiles):
    output = pathlib.Path(output)
    tile_size = tiles[0]

    with h5.File(input, mode="r") as f:
        # Read each field into memory once: h5py datasets do not support the
        # NumPy fancy indexing used by _extract_3d_tile_periodic.
        data = {field: f[field][()] for field in fields}

    for field in fields:
        (output / "tiles" / field).mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(num_tiles)):
        # Random starting corner within the field; wrapping handles the edges.
        a, b, c = (np.random.randint(0, high=s) for s in data[fields[0]].shape)
        for field in fields:
            tile = _extract_3d_tile_periodic(data[field], tile_size, (a, b, c))
            np.save(output / "tiles" / field / "{:04d}.npy".format(i), tile)
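A quick way to exercise the cropper end to end is click's test runner on a tiny synthetic field. The file, field, and sizes below are hypothetical, and the snippet assumes the fixes above (tile size taken from `--tiles`, output directories created).

```python
# Sketch: smoke-test the cropper on a small synthetic HDF5 field.
import numpy as np
import h5py as h5
from click.testing import CliRunner

from map2map.cropper import cropper

runner = CliRunner()
with runner.isolated_filesystem():
    # Write a 64^3 random "density" field to an HDF5 file.
    with h5.File("sim.h5", "w") as f:
        f["density"] = np.random.rand(64, 64, 64).astype(np.float32)

    result = runner.invoke(
        cropper,
        ["--input", "sim.h5", "--output", "out",
         "--tiles", "16", "--fields", "density", "--num_tiles", "4"],
    )
    assert result.exit_code == 0, result.exception
    # Tiles land in out/tiles/density/0000.npy ... 0003.npy
```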
map2map/main.py

@@ -1,17 +1,116 @@
-from .args import get_args
 from . import train
 from . import test
-
-
-def main():
-
-    args = get_args()
-
-    if args.mode == 'train':
-        train.node_worker(args)
-    elif args.mode == 'test':
-        test.test(args)
-
-
-if __name__ == '__main__':
-    main()
+import click
+import os
+import yaml
+
+try:
+    from yaml import CLoader as Loader
+except ImportError:
+    from yaml import Loader
+
+import importlib.resources
+import json
+from functools import partial
+
+
+def _load_resource_file(resource_path):
+    # Read a data file (e.g. an args YAML) bundled with this package.
+    # Anchoring on __package__ keeps this working on Python < 3.12.
+    pkg_files = importlib.resources.files(__package__) / resource_path
+    with pkg_files.open() as file:
+        return file.read()
+
+
+def _str_list(value):
+    return value.split(',')
+
+
+def _int_tuple(value):
+    t = value.split(',')
+    t = tuple(int(i) for i in t)
+    return t
+
+
+class VariadicType(click.ParamType):
+    """A custom parameter type for the click command-line interface.
+
+    It supports the types declared in the args YAML files: comma-separated
+    string lists and integer tuples, JSON, absolute paths, and plain
+    str, int, and float.
+
+    Args:
+        typename (str or dict): The name of the type, or a dictionary
+            specifying the type and its options.
+
+    Raises:
+        ValueError: If the typename is not recognized.
+    """
+
+    _mapper = {
+        "str_list": {"type": "string_list", "func": _str_list},
+        "int_tuple": {"type": "int_tuple", "func": _int_tuple},
+        "json": {"type": "json", "func": json.loads},
+        "int": {"type": "int"},
+        "float": {"type": "float"},
+        "str": {"type": "str"},
+        "abspath": {"type": "path", "func": os.path.abspath},
+    }
+
+    def __init__(self, typename):
+        if typename in self._mapper:
+            self._type = self._mapper[typename]
+        elif isinstance(typename, dict):
+            self._type = self._mapper[typename["type"]]
+            self.args = typename["opts"]
+        else:
+            raise ValueError(f"Unknown type: {typename}")
+        self._typename = typename
+        self.name = self._type["type"]
+        if "func" not in self._type:
+            # int, float and str fall back to the builtin of the same name
+            self._type["func"] = eval(self._type["type"])
+
+    def convert(self, value, param, ctx):
+        try:
+            return self._type["func"](value)
+        except Exception as e:
+            self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
+
+
+def _apply_options(options_file, f):
+    # Decorate f with one click option per entry in the YAML arguments file.
+    common_args = yaml.load(_load_resource_file(options_file), Loader=Loader)
+    common_args = common_args['arguments']
+
+    for arg in common_args:
+        argopt = common_args[arg]
+        if 'type' in argopt:
+            if isinstance(argopt['type'], dict) and argopt['type']['type'] == 'choice':
+                argopt['type'] = click.Choice(argopt['type']['opts'])
+            else:
+                argopt['type'] = VariadicType(argopt['type'])
+        f = click.option(f'--{arg}', **argopt)(f)
+
+    return f
+
+
+m2m_options = partial(_apply_options, "common_args.yaml")
+
+
+@click.group()
+@click.option("--config", type=click.Path(), help="Path to config file")
+@click.pass_context
+def main(ctx, config):
+    if config is not None and os.path.exists(config):
+        with open(config, 'r') as f:
+            config = yaml.load(f.read(), Loader=Loader)
+        ctx.default_map = config
+
+
+# A thin wrapper exposing dict entries as attributes, so train/test keep
+# their argparse-style `args.foo` access.
+class DictProxy:
+    def __init__(self, d):
+        self.__dict__ = d
+
+
+@main.command(name="train")
+@m2m_options
+@partial(_apply_options, "train_args.yaml")
+def train_cmd(**kwargs):
+    # named train_cmd so it does not shadow the imported train module
+    train.node_worker(DictProxy(kwargs))
+
+
+@main.command(name="test")
+@m2m_options
+@partial(_apply_options, "test_args.yaml")
+def test_cmd(**kwargs):
+    # named test_cmd so it does not shadow the imported test module
+    test.test(DictProxy(kwargs))
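Since `main()` assigns the loaded YAML straight to `ctx.default_map`, a `--config` file is interpreted by click as per-subcommand defaults, keyed by command name and by parameter name (dashes become underscores). A hypothetical config covering the required training options might look like the sketch below; all values are made up.

```python
# Sketch: write a hypothetical config file for `m2m --config config.yaml train`.
# Keys follow click's default_map convention: nested under the subcommand name,
# with option dashes turned into underscores.
import yaml

config = {
    "train": {
        "model": "UNet",
        "batch_size": 4,
        "lr": 1.0e-3,
        "train_style_pattern": "style/*.npy",
        "train_in_patterns": "in/*.npy",
        "train_tgt_patterns": "tgt/*.npy",
    }
}

with open("config.yaml", "w") as f:
    yaml.safe_dump(config, f)
```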
map2map/test_args.yaml (new file)

arguments:
  'test-style-pattern':
    type: str
    required: true
    help: glob pattern for test data styles
  'test-in-patterns':
    type: str_list
    required: true
    help: comma-sep. list of glob patterns for test input data
  'test-tgt-patterns':
    type: str_list
    required: true
    help: comma-sep. list of glob patterns for test target data
  'num-threads':
    type: int
    help: number of CPU threads when cuda is unavailable. Default is the number of CPUs on the node, as determined by slurm
map2map/train_args.yaml (new file)

arguments:
  'train-style-pattern':
    type: str
    required: true
    help: 'glob pattern for training data styles'
  'train-in-patterns':
    type: str_list
    required: true
    help: 'comma-sep. list of glob patterns for training input data'
  'train-tgt-patterns':
    type: str_list
    required: true
    help: 'comma-sep. list of glob patterns for training target data'
  'val-style-pattern':
    type: str
    help: 'glob pattern for validation data styles'
  'val-in-patterns':
    type: str_list
    help: 'comma-sep. list of glob patterns for validation input data'
  'val-tgt-patterns':
    type: str_list
    help: 'comma-sep. list of glob patterns for validation target data'
  'augment':
    is_flag: true
    help: 'enable data augmentation of axis flipping and permutation'
  'aug-shift':
    type: int_tuple
    help: 'data augmentation by shifting cropping by [0, aug_shift) pixels, useful for models that treat neighboring pixels differently, e.g. with strided convolutions. Comma-sep. list of 1 or d integers'
  'aug-add':
    type: float
    help: 'additive data augmentation, (normal) std, same factor for all fields'
  'aug-mul':
    type: float
    help: 'multiplicative data augmentation, (log-normal) std, same factor for all fields'
  'optimizer':
    default: 'Adam'
    type: str
    help: 'optimization algorithm'
  'lr':
    type: float
    required: true
    help: 'initial learning rate'
  'optimizer-args':
    default: '{}'
    type: json
    help: "optimizer arguments in addition to the learning rate, e.g. --optimizer-args '{\"betas\": [0.5, 0.9]}'"
  'reduce-lr-on-plateau':
    is_flag: true
    help: 'enable the ReduceLROnPlateau learning rate scheduler'
  'scheduler-args':
    default: '{"verbose": true}'
    type: json
    help: 'arguments for the ReduceLROnPlateau scheduler'
  'init-weight-std':
    type: float
    help: 'weight initialization std'
  'epochs':
    default: 128
    type: int
    help: 'total number of epochs to run'
  'seed':
    default: 42
    type: int
    help: 'seed for initializing training'
  'div-data':
    is_flag: true
    help: 'enable data division among GPUs for better page caching. Data division is shuffled every epoch. Only relevant if there are multiple crops in each field'
  'div-shuffle-dist':
    default: 1
    type: float
    help: 'distance to further shuffle cropped samples relative to their fields, to be used with --div-data. Only relevant if there are multiple crops in each file. The order of each sample is randomly displaced by this value. Setting it to 0 turns off this randomization, and setting it to N limits the shuffling to within a distance of N files. Change this to balance cache locality and stochasticity'
  'dist-backend':
    default: 'nccl'
    type:
      type: "choice"
      opts:
        - 'gloo'
        - 'nccl'
    help: 'distributed backend'
  'log-interval':
    default: 100
    type: int
    help: 'interval (batches) between logging training loss'
  'detect-anomaly':
    is_flag: true
    help: 'enable anomaly detection for the autograd engine'
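The nested `type:` mapping of `dist-backend` is the one case `_apply_options()` special-cases: it becomes a `click.Choice` instead of a `VariadicType`. Roughly, the generated option is equivalent to the sketch below (the command name is hypothetical).

```python
# Sketch: the option _apply_options() generates for 'dist-backend' above.
import click


@click.command()
@click.option("--dist-backend", default="nccl",
              type=click.Choice(["gloo", "nccl"]),
              help="distributed backend")
def cmd(dist_backend):
    click.echo(dist_backend)
```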
pyproject.toml

@@ -17,6 +17,8 @@ numpy = "^1.26.4"
 scipy = "^1.13.0"
 matplotlib = "^3.9.0"
 tensorboard = "^2.16.2"
+click = "^8.1.7"
+pyyaml = "^6.0.1"

 [tool.poetry.group.dev.dependencies]
 python-semantic-release = "^9.7.3"

@@ -42,7 +44,9 @@ dependencies = [
 'numpy',
 'scipy',
 'matplotlib',
-'tensorboard']
+'tensorboard',
+'h5py', 'tqdm',
+'click', 'pyyaml']

 authors = [
 {name = "Yin Li", email = "eelregit@gmail.com"},

@@ -54,10 +58,11 @@ maintainers = [
 ]

 [project.scripts]
-m2m = "map2map:main"
+m2m = "map2map:main.main"
+mapcropper = "map2map:cropper.cropper"

 [tool.poetry.scripts]
-map2map = "map2map:main"
+map2map = "map2map:main.main"


 [project.urls]