Add callback loading mechanism

Yin Li 2020-06-14 17:59:31 -04:00
parent 5bb2a19933
commit 90a0d6e0f8
10 changed files with 110 additions and 36 deletions

View File

@@ -7,6 +7,7 @@ Neural network emulators to transform field/map data
* [Data](#data)
* [Data normalization](#data-normalization)
* [Model](#model)
* [Customization](#customization)
* [Training](#training)
* [Files generated](#files-generated)
* [Tracking](#tracking)
@@ -23,33 +24,63 @@ pip install -e .
## Usage
Take a look at the examples in `scripts/*.slurm`, and the command line options
in `map2map/args.py` or by `m2m.py -h`.
The command is `m2m.py` in your `$PATH` after installation.
Take a look at the examples in `scripts/*.slurm`.
For all command line options look at `map2map/args.py` or do `m2m.py -h`.
### Data
Put each field in one npy file.
Structure your data to start with the channel axis and then the spatial
dimensions.
Put each sample in one file.
Specify the data path with glob patterns.
For example, a 2D vector field of size `64^2` should have shape `(2, 64, 64)`.
Specify the data path with
[glob patterns](https://docs.python.org/3/library/glob.html).
During training, pairs of input and target fields are loaded.
Both input and target data can consist of multiple fields, which are
then concatenated along the channel axis.
If a pair of input and target fields is too large to fit in GPU
memory, we can crop parts of them to form pairs of samples (see `--crop`).
Each field can be cropped multiple times along each dimension, with the
spacing between adjacent crops controlled by `--step`.
The total sample size is the number of input and target pairs multiplied
by the number of cropped samples per pair.
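A minimal data-preparation sketch (the file names and toy shapes are
hypothetical; the flags follow `map2map/args.py`):

```python
import os
import numpy as np

os.makedirs('train', exist_ok=True)

# One sample per npy file: channel axis first, then the spatial dims.
# Here a toy 2D vector field of size 64^2, i.e. shape (2, 64, 64).
for i in range(4):
    np.save(f'train/vel_{i:03d}.npy',
            np.random.randn(2, 64, 64).astype(np.float32))

# The files are then matched by a glob pattern, e.g.
#   m2m.py train --train-in-patterns 'train/vel_*.npy' ...
# With --crop 32 --step 32, each 64^2 field yields (64/32)^2 = 4 crops,
# so 4 input-target pairs would give 4 * 4 = 16 samples in total.
```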
#### Data normalization
Input and target (output) data can be normalized by functions defined in
`map2map/data/norms/`.
Also see [Customization](#customization).
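A minimal sketch of a user-defined norm, assuming the in-place
`norm(x, undo=False)` convention of the functions in
`map2map/data/norms/` (the module and function names here are
hypothetical):

```python
# ./mynorms.py -- used with, e.g., --in-norms mynorms.log1p --callback-at .
import torch

def log1p(x, undo=False):
    """Normalize a non-negative field in place with log(1 + x);
    undo=True inverts the transform with exp(x) - 1."""
    if not undo:
        torch.log1p(x, out=x)
    else:
        torch.expm1(x, out=x)
```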
### Model
Find the models in `map2map/models/`.
Customize the existing models, or add new models there and edit the `__init__.py`.
Modify the existing models, or write new models somewhere and then
follow [Customization](#customization).
### Training
### Customization
Models, criteria, optimizers and data normalizations can be customized
without modifying map2map.
They can be implemented as callbacks in a user directory, which is then
passed via `--callback-at`.
The default locations are searched before the callback directory, so
beware of name collisions.
This approach is good for experimentation.
For example, one can experiment with a model `Bar` in `./foo.py` by calling
`m2m.py` with `--model foo.Bar --callback-at .`, as in the sketch below.
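A minimal sketch of such a callback file (the layers are arbitrary; the
`(in_chan, out_chan)` constructor signature follows how `train.py` and
`test.py` instantiate models in this commit):

```python
# ./foo.py -- used with: m2m.py train --model foo.Bar --callback-at .
import torch.nn as nn

class Bar(nn.Module):
    """Toy model; constructed as Bar(sum(in_chan), sum(out_chan))."""

    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_chan, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, out_chan, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)
```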
#### Files generated
* `*.out`: job stdout and stderr

View File

@@ -1,3 +1,4 @@
import os
import argparse
import warnings
@@ -72,6 +73,11 @@ def add_common_args(parser):
help='maximum pairs of fields in cache, unlimited by default. '
'This only applies to training if not None, '
'in which case the testing cache maxsize is 1')
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
help='directory of custom code defining callbacks for models, '
'norms, criteria, and optimizers. Disabled if not set. '
'This is appended to the default locations, '
'thus has the lowest priority.')
def add_train_args(parser):

View File

@@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from .norms import import_norm
from ..utils import import_attr
from . import norms
class FieldDataset(Dataset):
@@ -38,7 +39,7 @@ class FieldDataset(Dataset):
This saves CPU RAM but limits stochasticity.
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None,
crop=None, pad=0, scale_factor=1,
cache=False, cache_maxsize=None, div_data=False,
@@ -65,13 +66,15 @@
if in_norms is not None:
assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match'
in_norms = [import_norm(norm) for norm in in_norms]
in_norms = [import_attr(norm, norms.__name__, callback_at)
for norm in in_norms]
self.in_norms = in_norms
if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match'
tgt_norms = [import_norm(norm) for norm in tgt_norms]
tgt_norms = [import_attr(norm, norms.__name__, callback_at)
for norm in tgt_norms]
self.tgt_norms = tgt_norms
self.augment = augment

View File

@@ -1,13 +0,0 @@
from importlib import import_module
from . import cosmology
def import_norm(norm):
if callable(norm):
return norm
mod, fun = norm.rsplit('.', 1)
mod = import_module('.' + mod, __name__)
fun = getattr(mod, fun)
return fun

View File

@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
from .state import load_model_state_dict
from .utils import import_attr, load_model_state_dict
def test(args):
@@ -18,6 +18,7 @@ def test(args):
tgt_patterns=args.test_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_add=None,
aug_mul=None,
@@ -35,9 +36,9 @@
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
model = getattr(models, args.model)
model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(in_chan), sum(out_chan))
criterion = getattr(torch.nn, args.criterion)
criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
criterion = criterion()
device = torch.device('cpu')

View File

@@ -6,6 +6,7 @@ from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
@@ -20,7 +21,7 @@ from .models import (narrow_like,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
from .state import load_model_state_dict
from .utils import import_attr, load_model_state_dict
ckpt_link = 'checkpoint.pth'
@@ -62,6 +63,7 @@ def gpu_worker(local_rank, node, args):
tgt_patterns=args.train_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=args.augment,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
@@ -100,6 +102,7 @@ def gpu_worker(local_rank, node, args):
tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_add=None,
aug_mul=None,
@@ -130,17 +133,17 @@ def gpu_worker(local_rank, node, args):
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
model = getattr(models, args.model)
model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(args.in_chan), sum(args.out_chan))
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())
criterion = getattr(nn, args.criterion)
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
criterion = criterion()
criterion.to(device)
optimizer = getattr(torch.optim, args.optimizer)
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
optimizer = optimizer(
model.parameters(),
lr=args.lr,
@@ -148,12 +151,12 @@ def gpu_worker(local_rank, node, args):
betas=(0.5, 0.999),
weight_decay=args.weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.1, patience=10, verbose=True)
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
if args.adv:
adv_model = getattr(models, args.adv_model)
adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(args.in_chan + args.out_chan)
if args.cgan else sum(args.out_chan), 1)
@@ -163,19 +166,19 @@ def gpu_worker(local_rank, node, args):
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group())
adv_criterion = getattr(nn, args.adv_criterion)
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
adv_criterion.to(device)
adv_optimizer = getattr(torch.optim, args.optimizer)
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
adv_optimizer = adv_optimizer(
adv_model.parameters(),
lr=args.adv_lr,
betas=(0.5, 0.999),
weight_decay=args.adv_weight_decay,
)
adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
factor=0.1, patience=10, verbose=True)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)

View File

@@ -0,0 +1,2 @@
from .imp import import_attr
from .state import load_model_state_dict

map2map/utils/imp.py Normal file
View File

@@ -0,0 +1,41 @@
import os
import importlib.util
def import_attr(name, pkg, callback_at=None):
"""Import attribute. Try package first and then callback directory.
To use a callback, `name` must contain a module, formatted as 'mod.attr'.
Examples
--------
>>> import_attr('attr', 'pkg1.pkg2')
tries to import attr from pkg1.pkg2.
>>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')
first tries to import attr from pkg1.pkg2.mod, then from
'path/to/cb_dir/mod.py'.
"""
if name.count('.') == 0:
attr = name
return getattr(importlib.import_module(pkg), attr)
else:
mod, attr = name.rsplit('.', 1)
try:
return getattr(importlib.import_module(pkg + '.' + mod), attr)
except (ModuleNotFoundError, AttributeError):
if callback_at is None:
raise
callback_at = os.path.join(callback_at, mod + '.py')
assert os.path.isfile(callback_at), 'callback file not found'
spec = importlib.util.spec_from_file_location(mod, callback_at)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return getattr(module, attr)
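A usage sketch of the resolution order (assuming the `UNet` model shipped
in `map2map/models/` and the `./foo.py` example from the README above):

```python
from map2map.models import UNet
from map2map.utils import import_attr

# An undotted name is looked up directly in the package.
assert import_attr('UNet', 'map2map.models') is UNet

# A dotted name tries map2map.models.foo first; on ModuleNotFoundError
# or AttributeError it falls back to ./foo.py in the callback directory.
Bar = import_attr('foo.Bar', 'map2map.models', callback_at='.')
```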

View File

@@ -1,5 +1,5 @@ import warnings
import warnings
import sys
import warnings
from pprint import pformat

View File

@@ -8,7 +8,7 @@ setup(
author='Yin Li et al.',
author_email='eelregit@gmail.com',
packages=find_packages(),
python_requires='>=3.2',
python_requires='>=3.6',
install_requires=[
'torch',
'numpy',