Add callback loading mechanism

Yin Li 2020-06-14 17:59:31 -04:00
parent 5bb2a19933
commit 90a0d6e0f8
10 changed files with 110 additions and 36 deletions

View File

@ -7,6 +7,7 @@ Neural network emulators to transform field/map data
 * [Data](#data)
 * [Data normalization](#data-normalization)
 * [Model](#model)
+* [Customization](#customization)
 * [Training](#training)
 * [Files generated](#files-generated)
 * [Tracking](#tracking)
@ -23,33 +24,63 @@ pip install -e .
## Usage

-Take a look at the examples in `scripts/*.slurm`, and the command line options
-in `map2map/args.py` or by `m2m.py -h`.
+The command is `m2m.py` in your `$PATH` after installation.
+Take a look at the examples in `scripts/*.slurm`.
+For all command line options look at `map2map/args.py` or do `m2m.py -h`.

### Data

+Put each field in one npy file.
Structure your data to start with the channel axis and then the spatial
dimensions.
-Put each sample in one file.
-Specify the data path with glob patterns.
+For example a 2D vector field of size `64^2` should have shape `(2, 64,
+64)`.
+Specify the data path with
+[glob patterns](https://docs.python.org/3/library/glob.html).
+
+During training, pairs of input and target fields are loaded.
+Both input and target data can consist of multiple fields, which are
+then concatenated along the channel axis.
+
+If the size of a pair of input and target fields is too large to fit in
+a GPU, we can crop part of them to form pairs of samples (see `--crop`).
+Each field can be cropped multiple times, along each dimension,
+controlled by the spacing between two adjacent crops (see `--step`).
+The total sample size is the number of input and target pairs multiplied
+by the number of cropped samples per pair.
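For concreteness, here is a minimal sketch of the data layout described above; the directory and file names are hypothetical, chosen only for illustration.

```python
import os
import numpy as np

# One field per npy file, channel axis first: a 2D vector field
# (2 channels) on a 64^2 mesh has shape (2, 64, 64).
os.makedirs('train/sample0', exist_ok=True)   # hypothetical layout
field = np.zeros((2, 64, 64), dtype=np.float32)
np.save('train/sample0/displacement.npy', field)

# Data paths are then given as glob patterns, e.g. 'train/*/displacement.npy',
# one pattern per field; multiple fields are concatenated along the channel axis.
```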
#### Data normalization

Input and target (output) data can be normalized by functions defined in
`map2map/data/norms/`.
+Also see [Customization](#customization).

### Model

Find the models in `map2map/models/`.
-Customize the existing models, or add new models there and edit the `__init__.py`.
+Modify the existing models, or write new models somewhere and then
+follow [Customization](#customization).
### Training

+### Customization
+
+Models, criteria, optimizers and data normalizations can be customized
+without modifying map2map.
+They can be implemented as callbacks in a user directory which is then
+passed by `--callback-at`.
+The default locations are searched first before the callback directory.
+So be aware of name collisions.
+
+This approach is good for experimentation.
+For example, one can play with a model `Bar` in `./foo.py`, by calling
+`m2m.py` with `--model foo.Bar --callback-at .`
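To make the `--model foo.Bar --callback-at .` example concrete, a user-side `./foo.py` might look like the sketch below. The `(in_chan, out_chan)` constructor arguments follow how `train.py` and `test.py` instantiate models, `model(sum(in_chan), sum(out_chan))`; the layer choices are made up and not part of map2map.

```python
# ./foo.py -- hypothetical user model, loaded with
#   m2m.py ... --model foo.Bar --callback-at .
import torch.nn as nn

class Bar(nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        # a toy 2D conv net; real models live in map2map/models/
        self.net = nn.Sequential(
            nn.Conv2d(in_chan, 64, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, out_chan, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)
```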
#### Files generated

* `*.out`: job stdout and stderr

View File

@ -1,3 +1,4 @@
+import os
import argparse
import warnings
@ -72,6 +73,11 @@ def add_common_args(parser):
            help='maximum pairs of fields in cache, unlimited by default. '
                 'This only applies to training if not None, '
                 'in which case the testing cache maxsize is 1')
+    parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
+            help='directory of custom code defining callbacks for models, '
+                 'norms, criteria, and optimizers. Disabled if not set. '
+                 'This is appended to the default locations, '
+                 'thus has the lowest priority.')
def add_train_args(parser):

View File

@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

-from .norms import import_norm
+from ..utils import import_attr
+from . import norms


class FieldDataset(Dataset):
@ -38,7 +39,7 @@ class FieldDataset(Dataset):
    This saves CPU RAM but limits stochasticity.
    """
    def __init__(self, in_patterns, tgt_patterns,
-                 in_norms=None, tgt_norms=None,
+                 in_norms=None, tgt_norms=None, callback_at=None,
                 augment=False, aug_add=None, aug_mul=None,
                 crop=None, pad=0, scale_factor=1,
                 cache=False, cache_maxsize=None, div_data=False,
@ -65,13 +66,15 @@ class FieldDataset(Dataset):
        if in_norms is not None:
            assert len(in_patterns) == len(in_norms), \
                'numbers of input normalization functions and fields do not match'
-            in_norms = [import_norm(norm) for norm in in_norms]
+            in_norms = [import_attr(norm, norms.__name__, callback_at)
+                        for norm in in_norms]
        self.in_norms = in_norms

        if tgt_norms is not None:
            assert len(tgt_patterns) == len(tgt_norms), \
                'numbers of target normalization functions and fields do not match'
-            tgt_norms = [import_norm(norm) for norm in tgt_norms]
+            tgt_norms = [import_attr(norm, norms.__name__, callback_at)
+                         for norm in tgt_norms]
        self.tgt_norms = tgt_norms

        self.augment = augment
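With the change above, each norm name is resolved through `import_attr`: built-in norms under `map2map.data.norms` are tried first, and anything else can come from the callback directory. A sketch of the lookup; the name `foo.mynorm` and the `./foo.py` file are hypothetical.

```python
from map2map.data import norms
from map2map.utils import import_attr

# A dotted name such as 'cosmology.<func>' resolves inside
# map2map.data.norms.cosmology.  'foo.mynorm' has no such module, so with
# callback_at set the lookup falls back to ./foo.py and returns its
# mynorm attribute.
mynorm = import_attr('foo.mynorm', norms.__name__, callback_at='.')
```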

View File

@ -1,13 +0,0 @@
-from importlib import import_module
-
-from . import cosmology
-
-
-def import_norm(norm):
-    if callable(norm):
-        return norm
-    mod, fun = norm.rsplit('.', 1)
-    mod = import_module('.' + mod, __name__)
-    fun = getattr(mod, fun)
-    return fun

View File

@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict


def test(args):
@ -18,6 +18,7 @@ def test(args):
        tgt_patterns=args.test_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
        augment=False,
        aug_add=None,
        aug_mul=None,
@ -35,9 +36,9 @@ def test(args):
    in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan

-    model = getattr(models, args.model)
+    model = import_attr(args.model, models.__name__, args.callback_at)
    model = model(sum(in_chan), sum(out_chan))
-    criterion = getattr(torch.nn, args.criterion)
+    criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
    criterion = criterion()

    device = torch.device('cpu')

View File

@ -6,6 +6,7 @@ from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
@ -20,7 +21,7 @@ from .models import (narrow_like,
                     adv_model_wrapper, adv_criterion_wrapper,
                     add_spectral_norm, rm_spectral_norm,
                     InstanceNoise)
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict


ckpt_link = 'checkpoint.pth'
@ -62,6 +63,7 @@ def gpu_worker(local_rank, node, args):
        tgt_patterns=args.train_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
        augment=args.augment,
        aug_add=args.aug_add,
        aug_mul=args.aug_mul,
@ -100,6 +102,7 @@ def gpu_worker(local_rank, node, args):
        tgt_patterns=args.val_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
        augment=False,
        aug_add=None,
        aug_mul=None,
@ -130,17 +133,17 @@ def gpu_worker(local_rank, node, args):
    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan

-    model = getattr(models, args.model)
+    model = import_attr(args.model, models.__name__, args.callback_at)
    model = model(sum(args.in_chan), sum(args.out_chan))
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[device],
            process_group=dist.new_group())

-    criterion = getattr(nn, args.criterion)
+    criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
    criterion = criterion()
    criterion.to(device)

-    optimizer = getattr(torch.optim, args.optimizer)
+    optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
    optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
@ -148,12 +151,12 @@ def gpu_worker(local_rank, node, args):
        betas=(0.5, 0.999),
        weight_decay=args.weight_decay,
    )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
            factor=0.1, patience=10, verbose=True)

    adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
    if args.adv:
-        adv_model = getattr(models, args.adv_model)
+        adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
        adv_model = adv_model_wrapper(adv_model)
        adv_model = adv_model(sum(args.in_chan + args.out_chan)
                if args.cgan else sum(args.out_chan), 1)
@ -163,19 +166,19 @@ def gpu_worker(local_rank, node, args):
        adv_model = DistributedDataParallel(adv_model, device_ids=[device],
                process_group=dist.new_group())

-        adv_criterion = getattr(nn, args.adv_criterion)
+        adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
        adv_criterion = adv_criterion_wrapper(adv_criterion)
        adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
        adv_criterion.to(device)

-        adv_optimizer = getattr(torch.optim, args.optimizer)
+        adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
        adv_optimizer = adv_optimizer(
            adv_model.parameters(),
            lr=args.adv_lr,
            betas=(0.5, 0.999),
            weight_decay=args.adv_weight_decay,
        )
-        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
+        adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
                factor=0.1, patience=10, verbose=True)

    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
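Criteria and optimizers now go through `import_attr` as well, so `--criterion` can name either a `torch.nn` loss or a class in the callback directory. Below is a hypothetical loss in `./foo.py`, usable as `--criterion foo.L1MSELoss --callback-at .`; it takes no constructor arguments, matching how `train.py` calls `criterion()`. The class name and the L1+MSE mix are made up for illustration.

```python
# ./foo.py -- hypothetical custom criterion, not part of map2map
import torch.nn as nn
import torch.nn.functional as F

class L1MSELoss(nn.Module):
    def forward(self, output, target):
        # MSE plus a small L1 term
        return F.mse_loss(output, target) + 0.1 * F.l1_loss(output, target)
```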

View File

@ -0,0 +1,2 @@
+from .imp import import_attr
+from .state import load_model_state_dict

map2map/utils/imp.py (new file)
View File

@ -0,0 +1,41 @@
import os
import importlib
import importlib.util


def import_attr(name, pkg, callback_at=None):
    """Import attribute. Try package first and then callback directory.

    To use a callback, `name` must contain a module, formatted as 'mod.attr'.

    Examples
    --------
    >>> import_attr('attr', 'pkg1.pkg2')
    tries to import attr from pkg1.pkg2.
    >>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')
    first tries to import attr from pkg1.pkg2.mod, then from
    'path/to/cb_dir/mod.py'.
    """
    if name.count('.') == 0:
        attr = name

        return getattr(importlib.import_module(pkg), attr)
    else:
        mod, attr = name.rsplit('.', 1)

        try:
            return getattr(importlib.import_module(pkg + '.' + mod), attr)
        except (ModuleNotFoundError, AttributeError):
            if callback_at is None:
                raise

            callback_at = os.path.join(callback_at, mod + '.py')

            assert os.path.isfile(callback_at), 'callback file not found'

            spec = importlib.util.spec_from_file_location(mod, callback_at)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            return getattr(module, attr)
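A small usage sketch of the fallback behavior documented in the docstring above; the callback module created below is purely illustrative.

```python
import os
import tempfile

from map2map.utils import import_attr

# No '.' in the name: resolved directly from the given package.
relu = import_attr('ReLU', 'torch.nn')   # -> torch.nn.ReLU

# Dotted name: 'torch.nn.foo' does not exist, so the lookup falls back
# to <cb_dir>/foo.py and returns its my_act attribute.
cb_dir = tempfile.mkdtemp()
with open(os.path.join(cb_dir, 'foo.py'), 'w') as f:
    f.write('def my_act(x):\n    return x.clamp(min=0)\n')
my_act = import_attr('foo.my_act', 'torch.nn', cb_dir)
```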

View File

@ -1,5 +1,5 @@
-import warnings
import sys
+import warnings
from pprint import pformat

View File

@ -8,7 +8,7 @@ setup(
    author='Yin Li et al.',
    author_email='eelregit@gmail.com',
    packages=find_packages(),
-    python_requires='>=3.2',
+    python_requires='>=3.6',
    install_requires=[
        'torch',
        'numpy',