Add much more flexible argument handling

This commit is contained in:
Guilhem Lavaux 2024-04-03 17:00:41 +02:00
parent 1c68ed397e
commit 4579651bba
3 changed files with 118 additions and 81 deletions

View File

@ -18,7 +18,7 @@ arguments:
type: int_tuple type: int_tuple
help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers' help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
- 'in-pad': - 'in-pad':
'pad': 0 pad: 0
type: int_tuple type: int_tuple
help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively' help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
- 'tgt-pad': - 'tgt-pad':

View File

@ -1,17 +1,53 @@
from .args import get_args from .args import get_args
from . import train from . import train
from . import test from . import test
import click
import os
import yaml
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
import importlib.resources
def main(): def _load_resource_file(resource_path):
package = importlib.import_module('map2map') # Import the package
with importlib.resources.path('map2map', resource_path) as path:
return path.read_text() # Read the file and return its content
def str_list(s):
return s.split(',')
def m2m_options(f):
common_args = _load_resource_file('common_args.yaml')
for arg in common_args['arguments']:
argopt = common_args[arg]
if 'type' in argopt:
argopt['type'] = eval(argopt['type'])
f = click.option(f'--{arg}', **argopt)(f)
else:
f = click.option(f'--{arg}', **argopt)(f)
return f
@click.group()
@click.option("--config", type=click.Path())
@click.pass_context
def main(ctx, config):
if os.path.exists(config):
with open(config, 'r') as f:
config = yaml.load(f.read(), Loader=Loader)
ctx.default_map = config
@main.command()
@m2m_options
def train(**kwargs):
args = get_args() args = get_args()
if args.mode == 'train':
train.node_worker(args) train.node_worker(args)
elif args.mode == 'test':
@main.command()
@m2m_options
def test():
test.test(args) test.test(args)
if __name__ == '__main__':
main()

View File

@ -17,7 +17,8 @@ dependencies = [
'numpy', 'numpy',
'scipy', 'scipy',
'matplotlib', 'matplotlib',
'tensorboard'] 'tensorboard',
'click']
authors = [ authors = [
{name = "Yin Li", email = "eelregit@gmail.com"}, {name = "Yin Li", email = "eelregit@gmail.com"},
@ -29,7 +30,7 @@ maintainers = [
] ]
[project.scripts] [project.scripts]
m2m = "map2map:main" m2m = "map2map:main.main"
[project.urls] [project.urls]
#Homepage = "https://example.com" #Homepage = "https://example.com"