117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
from . import train
|
|
from . import test
|
|
import click
|
|
import os
|
|
import yaml
|
|
try:
|
|
from yaml import CLoader as Loader
|
|
except ImportError:
|
|
from yaml import Loader
|
|
|
|
import importlib.resources
|
|
import json
|
|
from functools import partial
|
|
|
|
def _load_resource_file(resource_path):
    """Return the text content of a data file bundled with this package.

    Args:
        resource_path: Path of the resource relative to the package
            directory (for example ``"common_args.yaml"``).

    Returns:
        str: Everything read from the resource, opened in text mode.
    """
    # Pass the anchor package explicitly: calling ``files()`` without an
    # argument is only accepted from Python 3.12 onwards, while this module
    # always lives inside a package (it uses relative imports).
    resource = importlib.resources.files(__package__) / resource_path
    with resource.open() as file:
        return file.read()
|
|
|
|
def _str_list(value):
|
|
return value.split(',')
|
|
|
|
def _int_tuple(value):
|
|
t = value.split(',')
|
|
t = tuple(int(i) for i in t)
|
|
return t
|
|
|
|
class VariadicType(click.ParamType):
    """
    A custom parameter type for Click command-line interface.

    The type is selected by name from a fixed table of converters:
    ``str_list``, ``int_tuple``, ``json``, ``int``, ``float``, ``str`` and
    ``abspath``. The selected converter is applied to the raw option value
    in :meth:`convert`.

    Args:
        typename (str or dict): Either one of the names above, or a
            dictionary whose ``"type"`` entry is one of those names. For the
            dictionary form the ``"opts"`` entry (``None`` when absent) is
            stored on ``self.args``.

    Raises:
        ValueError: If the typename is not recognized.
    """

    # Every entry carries its converter explicitly, so nothing has to be
    # resolved with eval() and this shared table is never mutated at runtime.
    _mapper = {
        "str_list": {"type": "string_list", "func": _str_list},
        "int_tuple": {"type": "int_tuple", "func": _int_tuple},
        "json": {"type": "json", "func": json.loads},
        "int": {"type": "int", "func": int},
        "float": {"type": "float", "func": float},
        "str": {"type": "str", "func": str},
        "abspath": {"type": "path", "func": os.path.abspath},
    }

    def __init__(self, typename):
        # Handle the dict form before any table lookup: a dict is unhashable,
        # so ``typename in self._mapper`` would raise TypeError for it.
        if isinstance(typename, dict):
            key = typename.get("type")
        else:
            key = typename

        # The isinstance guard also keeps other unhashable values (e.g. a
        # list) on the documented ValueError path.
        if not isinstance(key, str) or key not in self._mapper:
            raise ValueError(f"Unknown type: {typename}")

        self._type = self._mapper[key]
        if isinstance(typename, dict):
            self.args = typename.get("opts")
        self._typename = typename
        # ``name`` is the attribute click uses when describing the type.
        self.name = self._type["type"]

    def convert(self, value, param, ctx):
        """Apply the selected converter to ``value``.

        Any exception raised by the converter is reported through
        ``self.fail`` so that click presents it as a usage error.
        """
        try:
            return self._type["func"](value)
        except Exception as e:
            self.fail(f"Could not parse {self._typename}: {e}", param, ctx)
|
|
|
|
def _apply_options(options_file, f):
    """Decorate ``f`` with one ``click.option`` per entry of a YAML resource.

    The resource is expected to contain a top-level ``arguments`` mapping
    from option name to the keyword arguments handed to ``click.option``.
    A ``type`` entry is translated before use: a dictionary whose ``"type"``
    is ``"choice"`` becomes ``click.Choice`` built from its ``"opts"``; any
    other value is wrapped in ``VariadicType``.

    Args:
        options_file: Resource name passed to ``_load_resource_file``.
        f: The callable to decorate.

    Returns:
        ``f`` after all options have been applied to it.
    """
    # NOTE(review): Loader/CLoader are not PyYAML's safe loaders. That is
    # presumably acceptable because the file ships inside the package -
    # confirm these resources can never come from untrusted input.
    document = yaml.load(_load_resource_file(options_file), Loader=Loader)

    for name, argopt in document['arguments'].items():
        if 'type' in argopt:
            declared = argopt['type']
            if isinstance(declared, dict) and declared['type'] == 'choice':
                argopt['type'] = click.Choice(declared['opts'])
            else:
                argopt['type'] = VariadicType(declared)
        # Typed and untyped options are registered the same way.
        f = click.option(f'--{name}', **argopt)(f)

    return f
|
|
|
|
# Decorator that applies the options declared in common_args.yaml.
m2m_options = partial(_apply_options, "common_args.yaml")
|
|
|
|
|
|
@click.group()
@click.option("--config", type=click.Path(), help="Path to config file")
@click.pass_context
def main(ctx, config):
    # Root command group. Documented with comments on purpose: click would
    # display a docstring here as the group's help text.
    #
    # Without --config, or when the given path does not exist, the context
    # is left untouched.
    if config is None or not os.path.exists(config):
        return

    # NOTE(review): the file is user-supplied and Loader/CLoader are not
    # PyYAML's safe loaders - confirm this is intended.
    with open(config, 'r') as stream:
        defaults = yaml.load(stream.read(), Loader=Loader)

    # Values loaded from the file become the defaults for the sub-commands.
    ctx.default_map = defaults
|
|
|
|
class DictProxy:
    """Provide attribute-style access to the entries of a dict.

    The instance adopts ``d`` itself as its attribute dictionary (no copy is
    made), so attributes set on the proxy appear in ``d`` and keys added to
    ``d`` become attributes of the proxy.
    """

    def __init__(self, d):
        self.__dict__ = d
|
|
|
|
@main.command()
@m2m_options
@partial(_apply_options, "train_args.yaml")
def train(**kwargs):
    # Defining this command rebinds the module-level name ``train`` to the
    # click command, hiding the ``train`` submodule imported at the top of
    # the file. Looking the submodule up by name avoids calling
    # ``node_worker`` on the command object (an AttributeError).
    # ``importlib`` is already bound by the file's ``import importlib.resources``.
    train_module = importlib.import_module(".train", __package__)

    # The worker receives the parsed options as attributes, not dict keys.
    train_module.node_worker(DictProxy(kwargs))
|
|
|
|
@main.command()
@m2m_options
@partial(_apply_options, "test_args.yaml")
def test(**kwargs):
    # Defining this command rebinds the module-level name ``test`` to the
    # click command, hiding the ``test`` submodule imported at the top of
    # the file. Looking the submodule up by name avoids calling ``test`` on
    # the command object (an AttributeError).
    # ``importlib`` is already bound by the file's ``import importlib.resources``.
    test_module = importlib.import_module(".test", __package__)

    # The submodule receives the parsed options as attributes, not dict keys.
    test_module.test(DictProxy(kwargs))
|