Add much more flexible argument handling

This commit is contained in:
Guilhem Lavaux 2024-04-03 17:00:41 +02:00
parent 6bf93a9835
commit 813a4e95ee
3 changed files with 118 additions and 81 deletions

View File

@ -1,88 +1,88 @@
arguments: arguments:
- 'in-norms': - 'in-norms':
type: str_list type: str_list
help: 'comma-sep. list of input normalization functions' help: 'comma-sep. list of input normalization functions'
- 'tgt-norms': - 'tgt-norms':
type: str_list type: str_list
help: 'comma-sep. list of target normalization functions' help: 'comma-sep. list of target normalization functions'
- 'crop': - 'crop':
type: int_tuple type: int_tuple
help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers' help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
- 'crop-start': - 'crop-start':
type: int_tuple type: int_tuple
help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers' help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
- 'crop-stop': - 'crop-stop':
type: int_tuple type: int_tuple
help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers' help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
- 'crop-step': - 'crop-step':
type: int_tuple type: int_tuple
help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
- 'in-pad': - 'in-pad':
'pad': 0 pad: 0
type: int_tuple type: int_tuple
help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
- 'tgt-pad': - 'tgt-pad':
default: 0 default: 0
type: int_tuple type: int_tuple
help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad' help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
- 'scale-factor': - 'scale-factor':
default: 1 default: 1
type: int type: int
help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution' help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
- 'model': - 'model':
type: str type: str
required: true required: true
help: '(generator) model' help: '(generator) model'
- 'criterion': - 'criterion':
default: 'MSELoss' default: 'MSELoss'
type: str type: str
help: 'loss function' help: 'loss function'
- 'load-state': - 'load-state':
default: ckpt_link default: ckpt_link
type: str type: str
help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint' help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
- 'load-state-non-strict': - 'load-state-non-strict':
action: 'store_false' action: 'store_false'
help: 'allow incompatible keys when loading model states' help: 'allow incompatible keys when loading model states'
dest: 'load_state_strict' dest: 'load_state_strict'
- 'batch-size': - 'batch-size':
'batches': 0 'batches': 0
type: int type: int
required: true required: true
help: 'mini-batch size, per GPU in training or in total in testing' help: 'mini-batch size, per GPU in training or in total in testing'
- 'loader-workers': - 'loader-workers':
default: 8 default: 8
type: int type: int
help: 'number of subprocesses per data loader. 0 to disable multiprocessing' help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
- 'callback-at': - 'callback-at':
type: 'lambda s: os.path.abspath(s)' type: 'lambda s: os.path.abspath(s)'
help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority' help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
- 'misc-kwargs': - 'misc-kwargs':
default: '{}' default: '{}'
type: json.loads type: json.loads
help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions' help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'
arguments: arguments:
- 'optimizer': - 'optimizer':
default: 'Adam' default: 'Adam'
type: str type: str
help: 'optimizer for training' help: 'optimizer for training'
- 'learning-rate': - 'learning-rate':
default: 0.001 default: 0.001
type: float type: float
help: 'learning rate for training' help: 'learning rate for training'
- 'num-epochs': - 'num-epochs':
default: 100 default: 100
type: int type: int
help: 'number of training epochs' help: 'number of training epochs'
- 'save-interval': - 'save-interval':
default: 10 default: 10
type: int type: int
help: 'interval for saving checkpoints during training' help: 'interval for saving checkpoints during training'
- 'log-interval': - 'log-interval':
default: 10 default: 10
type: int type: int
help: 'interval for logging training progress' help: 'interval for logging training progress'
- 'device': - 'device':
default: 'cuda' default: 'cuda'
type: str type: str
help: 'device for training (cuda or cpu)' help: 'device for training (cuda or cpu)'

View File

@ -1,17 +1,53 @@
from .args import get_args from .args import get_args
from . import train from . import train
from . import test from . import test
import click
import os
import yaml
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
import importlib.resources
def main(): def _load_resource_file(resource_path):
package = importlib.import_module('map2map') # Import the package
with importlib.resources.path('map2map', resource_path) as path:
return path.read_text() # Read the file and return its content
def str_list(s):
return s.split(',')
def m2m_options(f):
common_args = _load_resource_file('common_args.yaml')
for arg in common_args['arguments']:
argopt = common_args[arg]
if 'type' in argopt:
argopt['type'] = eval(argopt['type'])
f = click.option(f'--{arg}', **argopt)(f)
else:
f = click.option(f'--{arg}', **argopt)(f)
return f
@click.group()
@click.option("--config", type=click.Path())
@click.pass_context
def main(ctx, config):
if os.path.exists(config):
with open(config, 'r') as f:
config = yaml.load(f.read(), Loader=Loader)
ctx.default_map = config
@main.command()
@m2m_options
def train(**kwargs):
args = get_args() args = get_args()
train.node_worker(args)
if args.mode == 'train': @main.command()
train.node_worker(args) @m2m_options
elif args.mode == 'test': def test():
test.test(args) test.test(args)
if __name__ == '__main__':
main()

View File

@ -37,7 +37,8 @@ dependencies = [
'numpy', 'numpy',
'scipy', 'scipy',
'matplotlib', 'matplotlib',
'tensorboard'] 'tensorboard',
'click']
authors = [ authors = [
{name = "Yin Li", email = "eelregit@gmail.com"}, {name = "Yin Li", email = "eelregit@gmail.com"},
@ -49,7 +50,7 @@ maintainers = [
] ]
[project.scripts] [project.scripts]
m2m = "map2map:main" m2m = "map2map:main.main"
[tool.poetry.scripts] [tool.poetry.scripts]
map2map = "map2map:main" map2map = "map2map:main"