Add much more flexible argument handling
This commit is contained in:
parent
1c68ed397e
commit
4579651bba
@ -1,88 +1,88 @@
|
|||||||
arguments:
|
arguments:
|
||||||
- 'in-norms':
|
- 'in-norms':
|
||||||
type: str_list
|
type: str_list
|
||||||
help: 'comma-sep. list of input normalization functions'
|
help: 'comma-sep. list of input normalization functions'
|
||||||
- 'tgt-norms':
|
- 'tgt-norms':
|
||||||
type: str_list
|
type: str_list
|
||||||
help: 'comma-sep. list of target normalization functions'
|
help: 'comma-sep. list of target normalization functions'
|
||||||
- 'crop':
|
- 'crop':
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
|
help: 'size to crop the input and target data. Default is the field size. Comma-sep. list of 1 or d integers'
|
||||||
- 'crop-start':
|
- 'crop-start':
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
|
help: 'starting point of the first crop. Default is the origin. Comma-sep. list of 1 or d integers'
|
||||||
- 'crop-stop':
|
- 'crop-stop':
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
|
help: 'stopping point of the last crop. Default is the opposite corner to the origin. Comma-sep. list of 1 or d integers'
|
||||||
- 'crop-step':
|
- 'crop-step':
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
|
help: 'spacing between crops. Default is the crop size. Comma-sep. list of 1 or d integers'
|
||||||
- 'in-pad':
|
- 'in-pad':
|
||||||
'pad': 0
|
pad: 0
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
|
help: 'size to pad the input data beyond the crop size, assuming periodic boundary condition. Comma-sep. list of 1, d, or dx2 integers, to pad equally along all axes, symmetrically on each, or by the specified size on every boundary, respectively'
|
||||||
- 'tgt-pad':
|
- 'tgt-pad':
|
||||||
default: 0
|
default: 0
|
||||||
type: int_tuple
|
type: int_tuple
|
||||||
help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
|
help: 'size to pad the target data beyond the crop size, assuming periodic boundary condition, useful for super-resolution. Comma-sep. list with the same format as in-pad'
|
||||||
- 'scale-factor':
|
- 'scale-factor':
|
||||||
default: 1
|
default: 1
|
||||||
type: int
|
type: int
|
||||||
help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
|
help: 'upsampling factor for super-resolution, in which case crop and pad are sizes of the input resolution'
|
||||||
- 'model':
|
- 'model':
|
||||||
type: str
|
type: str
|
||||||
required: true
|
required: true
|
||||||
help: '(generator) model'
|
help: '(generator) model'
|
||||||
- 'criterion':
|
- 'criterion':
|
||||||
default: 'MSELoss'
|
default: 'MSELoss'
|
||||||
type: str
|
type: str
|
||||||
help: 'loss function'
|
help: 'loss function'
|
||||||
- 'load-state':
|
- 'load-state':
|
||||||
default: ckpt_link
|
default: ckpt_link
|
||||||
type: str
|
type: str
|
||||||
help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
|
help: 'path to load the states of model, optimizer, rng, etc. Default is the checkpoint. Start from scratch in case of empty string or missing checkpoint'
|
||||||
- 'load-state-non-strict':
|
- 'load-state-non-strict':
|
||||||
action: 'store_false'
|
action: 'store_false'
|
||||||
help: 'allow incompatible keys when loading model states'
|
help: 'allow incompatible keys when loading model states'
|
||||||
dest: 'load_state_strict'
|
dest: 'load_state_strict'
|
||||||
- 'batch-size':
|
- 'batch-size':
|
||||||
'batches': 0
|
'batches': 0
|
||||||
type: int
|
type: int
|
||||||
required: true
|
required: true
|
||||||
help: 'mini-batch size, per GPU in training or in total in testing'
|
help: 'mini-batch size, per GPU in training or in total in testing'
|
||||||
- 'loader-workers':
|
- 'loader-workers':
|
||||||
default: 8
|
default: 8
|
||||||
type: int
|
type: int
|
||||||
help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
|
help: 'number of subprocesses per data loader. 0 to disable multiprocessing'
|
||||||
- 'callback-at':
|
- 'callback-at':
|
||||||
type: 'lambda s: os.path.abspath(s)'
|
type: 'lambda s: os.path.abspath(s)'
|
||||||
help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
|
help: 'directory of custorm code defining callbacks for models, norms, criteria, and optimizers. Disabled if not set. This is appended to the default locations, thus has the lowest priority'
|
||||||
- 'misc-kwargs':
|
- 'misc-kwargs':
|
||||||
default: '{}'
|
default: '{}'
|
||||||
type: json.loads
|
type: json.loads
|
||||||
help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'
|
help: 'miscellaneous keyword arguments for custom models and norms. Be careful with name collisions'
|
||||||
arguments:
|
arguments:
|
||||||
- 'optimizer':
|
- 'optimizer':
|
||||||
default: 'Adam'
|
default: 'Adam'
|
||||||
type: str
|
type: str
|
||||||
help: 'optimizer for training'
|
help: 'optimizer for training'
|
||||||
- 'learning-rate':
|
- 'learning-rate':
|
||||||
default: 0.001
|
default: 0.001
|
||||||
type: float
|
type: float
|
||||||
help: 'learning rate for training'
|
help: 'learning rate for training'
|
||||||
- 'num-epochs':
|
- 'num-epochs':
|
||||||
default: 100
|
default: 100
|
||||||
type: int
|
type: int
|
||||||
help: 'number of training epochs'
|
help: 'number of training epochs'
|
||||||
- 'save-interval':
|
- 'save-interval':
|
||||||
default: 10
|
default: 10
|
||||||
type: int
|
type: int
|
||||||
help: 'interval for saving checkpoints during training'
|
help: 'interval for saving checkpoints during training'
|
||||||
- 'log-interval':
|
- 'log-interval':
|
||||||
default: 10
|
default: 10
|
||||||
type: int
|
type: int
|
||||||
help: 'interval for logging training progress'
|
help: 'interval for logging training progress'
|
||||||
- 'device':
|
- 'device':
|
||||||
default: 'cuda'
|
default: 'cuda'
|
||||||
type: str
|
type: str
|
||||||
help: 'device for training (cuda or cpu)'
|
help: 'device for training (cuda or cpu)'
|
@ -1,17 +1,53 @@
|
|||||||
from .args import get_args
|
from .args import get_args
|
||||||
from . import train
|
from . import train
|
||||||
from . import test
|
from . import test
|
||||||
|
import click
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
try:
|
||||||
|
from yaml import CLoader as Loader
|
||||||
|
except ImportError:
|
||||||
|
from yaml import Loader
|
||||||
|
|
||||||
|
import importlib.resources
|
||||||
|
|
||||||
def main():
|
def _load_resource_file(resource_path):
|
||||||
|
package = importlib.import_module('map2map') # Import the package
|
||||||
|
with importlib.resources.path('map2map', resource_path) as path:
|
||||||
|
return path.read_text() # Read the file and return its content
|
||||||
|
|
||||||
|
def str_list(s):
|
||||||
|
return s.split(',')
|
||||||
|
|
||||||
|
def m2m_options(f):
|
||||||
|
common_args = _load_resource_file('common_args.yaml')
|
||||||
|
|
||||||
|
for arg in common_args['arguments']:
|
||||||
|
argopt = common_args[arg]
|
||||||
|
if 'type' in argopt:
|
||||||
|
argopt['type'] = eval(argopt['type'])
|
||||||
|
f = click.option(f'--{arg}', **argopt)(f)
|
||||||
|
else:
|
||||||
|
f = click.option(f'--{arg}', **argopt)(f)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
@click.group()
|
||||||
|
@click.option("--config", type=click.Path())
|
||||||
|
@click.pass_context
|
||||||
|
def main(ctx, config):
|
||||||
|
if os.path.exists(config):
|
||||||
|
with open(config, 'r') as f:
|
||||||
|
config = yaml.load(f.read(), Loader=Loader)
|
||||||
|
ctx.default_map = config
|
||||||
|
|
||||||
|
@main.command()
|
||||||
|
@m2m_options
|
||||||
|
def train(**kwargs):
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
train.node_worker(args)
|
||||||
|
|
||||||
if args.mode == 'train':
|
@main.command()
|
||||||
train.node_worker(args)
|
@m2m_options
|
||||||
elif args.mode == 'test':
|
def test():
|
||||||
test.test(args)
|
test.test(args)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
|
@ -17,7 +17,8 @@ dependencies = [
|
|||||||
'numpy',
|
'numpy',
|
||||||
'scipy',
|
'scipy',
|
||||||
'matplotlib',
|
'matplotlib',
|
||||||
'tensorboard']
|
'tensorboard',
|
||||||
|
'click']
|
||||||
|
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Yin Li", email = "eelregit@gmail.com"},
|
{name = "Yin Li", email = "eelregit@gmail.com"},
|
||||||
@ -29,7 +30,7 @@ maintainers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
m2m = "map2map:main"
|
m2m = "map2map:main.main"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
#Homepage = "https://example.com"
|
#Homepage = "https://example.com"
|
||||||
|
Loading…
Reference in New Issue
Block a user