Add callback loading mechanism
parent 5bb2a19933
commit 90a0d6e0f8
README.md
@@ -7,6 +7,7 @@ Neural network emulators to transform field/map data
* [Data](#data)
* [Data normalization](#data-normalization)
* [Model](#model)
* [Customization](#customization)
* [Training](#training)
* [Files generated](#files-generated)
* [Tracking](#tracking)
@@ -23,33 +24,63 @@ pip install -e .

## Usage

Take a look at the examples in `scripts/*.slurm`, and the command line options
in `map2map/args.py` or by `m2m.py -h`.
The command is `m2m.py` in your `$PATH` after installation.
Take a look at the examples in `scripts/*.slurm`.
For all command line options look at `map2map/args.py` or do `m2m.py -h`.


### Data

Put each field in one npy file.
Structure your data to start with the channel axis and then the spatial
dimensions.
Put each sample in one file.
Specify the data path with glob patterns.
For example, a 2D vector field of size `64^2` should have shape `(2, 64,
64)`.
Specify the data path with
[glob patterns](https://docs.python.org/3/library/glob.html).

During training, pairs of input and target fields are loaded.
Both input and target data can consist of multiple fields, which are
then concatenated along the channel axis.
If the size of a pair of input and target fields is too large to fit in
a GPU, we can crop part of them to form pairs of samples (see `--crop`).
Each field can be cropped multiple times, along each dimension,
controlled by the spacing between two adjacent crops (see `--step`).
The total sample size is the number of input and target pairs multiplied
by the number of cropped samples per pair.
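
For instance, one input field could be prepared like this (a minimal sketch;
the file name, directory layout, and the `--train-in-patterns` flag spelling
are illustrative assumptions):

```python
import os
import numpy as np

# a 2D vector field of size 64^2: channel axis first, then the spatial dims
field = np.random.randn(2, 64, 64).astype(np.float32)

# one sample per file; a numbered naming scheme keeps the glob pattern simple,
# e.g. something like --train-in-patterns "train/velocity_*.npy"
os.makedirs('train', exist_ok=True)
np.save('train/velocity_0000.npy', field)
```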


#### Data normalization

Input and target (output) data can be normalized by functions defined in
`map2map/data/norms/`.
Also see [Customization](#customization).
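
As a rough sketch, a normalization callback could be a function like the one
below, placed either in `map2map/data/norms/` or in a callback directory (the
in-place call convention and the `undo` flag for inverting the transform are
assumptions, not spelled out in this commit):

```python
import torch

# hypothetical norm, e.g. saved as ./mynorms.py and selected with
# --in-norms mynorms.log1p --callback-at .
def log1p(x, undo=False):
    if not undo:
        torch.log1p(x, out=x)   # forward: x -> log(1 + x), in place
    else:
        torch.expm1(x, out=x)   # inverse: x -> exp(x) - 1, in place
```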


### Model

Find the models in `map2map/models/`.
Customize the existing models, or add new models there and edit the `__init__.py`.
Modify the existing models, or write new models somewhere and then
follow [Customization](#customization).


### Training


### Customization

Models, criteria, optimizers and data normalizations can be customized
without modifying map2map.
They can be implemented as callbacks in a user directory which is then
passed by `--callback-at`.
The default locations are searched before the callback directory.
So be aware of name collisions.

This approach is good for experimentation.
For example, one can play with a model `Bar` in `./foo.py` by calling
`m2m.py` with `--model foo.Bar --callback-at .`
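
A minimal sketch of such a callback file; the layers are placeholders, and the
only interface assumed (from the training code in this commit) is a
constructor taking the summed input and output channel counts:

```python
# hypothetical ./foo.py
import torch.nn as nn


class Bar(nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        # toy convolutional net; map2map constructs it as
        # Bar(sum(in_chan), sum(out_chan))
        self.net = nn.Sequential(
            nn.Conv3d(in_chan, 64, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv3d(64, out_chan, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)
```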


#### Files generated

* `*.out`: job stdout and stderr
@@ -1,3 +1,4 @@
import os
import argparse
import warnings

@@ -72,6 +73,11 @@ def add_common_args(parser):
            help='maximum pairs of fields in cache, unlimited by default. '
                 'This only applies to training if not None, '
                 'in which case the testing cache maxsize is 1')
    parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
            help='directory of custom code defining callbacks for models, '
                 'norms, criteria, and optimizers. Disabled if not set. '
                 'This is appended to the default locations, '
                 'thus has the lowest priority.')


def add_train_args(parser):
@@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from .norms import import_norm
from ..utils import import_attr
from . import norms


class FieldDataset(Dataset):
@@ -38,7 +39,7 @@ class FieldDataset(Dataset):
    This saves CPU RAM but limits stochasticity.
    """
    def __init__(self, in_patterns, tgt_patterns,
                 in_norms=None, tgt_norms=None,
                 in_norms=None, tgt_norms=None, callback_at=None,
                 augment=False, aug_add=None, aug_mul=None,
                 crop=None, pad=0, scale_factor=1,
                 cache=False, cache_maxsize=None, div_data=False,
@@ -65,13 +66,15 @@ class FieldDataset(Dataset):
        if in_norms is not None:
            assert len(in_patterns) == len(in_norms), \
                'numbers of input normalization functions and fields do not match'
            in_norms = [import_norm(norm) for norm in in_norms]
            in_norms = [import_attr(norm, norms.__name__, callback_at)
                        for norm in in_norms]
        self.in_norms = in_norms

        if tgt_norms is not None:
            assert len(tgt_patterns) == len(tgt_norms), \
                'numbers of target normalization functions and fields do not match'
            tgt_norms = [import_norm(norm) for norm in tgt_norms]
            tgt_norms = [import_attr(norm, norms.__name__, callback_at)
                         for norm in tgt_norms]
        self.tgt_norms = tgt_norms

        self.augment = augment
@@ -1,13 +0,0 @@
from importlib import import_module

from . import cosmology


def import_norm(norm):
    if callable(norm):
        return norm

    mod, fun = norm.rsplit('.', 1)
    mod = import_module('.' + mod, __name__)
    fun = getattr(mod, fun)
    return fun
@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
from .state import load_model_state_dict
from .utils import import_attr, load_model_state_dict


def test(args):
@@ -18,6 +18,7 @@ def test(args):
        tgt_patterns=args.test_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
        callback_at=args.callback_at,
        augment=False,
        aug_add=None,
        aug_mul=None,
@@ -35,9 +36,9 @@ def test(args):

    in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan

    model = getattr(models, args.model)
    model = import_attr(args.model, models.__name__, args.callback_at)
    model = model(sum(in_chan), sum(out_chan))
    criterion = getattr(torch.nn, args.criterion)
    criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
    criterion = criterion()

    device = torch.device('cpu')
@@ -6,6 +6,7 @@ from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
@@ -20,7 +21,7 @@ from .models import (narrow_like,
        adv_model_wrapper, adv_criterion_wrapper,
        add_spectral_norm, rm_spectral_norm,
        InstanceNoise)
from .state import load_model_state_dict
from .utils import import_attr, load_model_state_dict


ckpt_link = 'checkpoint.pth'
@@ -62,6 +63,7 @@ def gpu_worker(local_rank, node, args):
        tgt_patterns=args.train_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
        callback_at=args.callback_at,
        augment=args.augment,
        aug_add=args.aug_add,
        aug_mul=args.aug_mul,
@@ -100,6 +102,7 @@ def gpu_worker(local_rank, node, args):
        tgt_patterns=args.val_tgt_patterns,
        in_norms=args.in_norms,
        tgt_norms=args.tgt_norms,
        callback_at=args.callback_at,
        augment=False,
        aug_add=None,
        aug_mul=None,
@@ -130,17 +133,17 @@ def gpu_worker(local_rank, node, args):

    args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan

    model = getattr(models, args.model)
    model = import_attr(args.model, models.__name__, args.callback_at)
    model = model(sum(args.in_chan), sum(args.out_chan))
    model.to(device)
    model = DistributedDataParallel(model, device_ids=[device],
            process_group=dist.new_group())

    criterion = getattr(nn, args.criterion)
    criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
    criterion = criterion()
    criterion.to(device)

    optimizer = getattr(torch.optim, args.optimizer)
    optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
    optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
@@ -148,12 +151,12 @@ def gpu_worker(local_rank, node, args):
        betas=(0.5, 0.999),
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
            factor=0.1, patience=10, verbose=True)

    adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
    if args.adv:
        adv_model = getattr(models, args.adv_model)
        adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
        adv_model = adv_model_wrapper(adv_model)
        adv_model = adv_model(sum(args.in_chan + args.out_chan)
                if args.cgan else sum(args.out_chan), 1)
@@ -163,19 +166,19 @@ def gpu_worker(local_rank, node, args):
        adv_model = DistributedDataParallel(adv_model, device_ids=[device],
                process_group=dist.new_group())

        adv_criterion = getattr(nn, args.adv_criterion)
        adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
        adv_criterion = adv_criterion_wrapper(adv_criterion)
        adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
        adv_criterion.to(device)

        adv_optimizer = getattr(torch.optim, args.optimizer)
        adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
        adv_optimizer = adv_optimizer(
            adv_model.parameters(),
            lr=args.adv_lr,
            betas=(0.5, 0.999),
            weight_decay=args.adv_weight_decay,
        )
        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
        adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
                factor=0.1, patience=10, verbose=True)

    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
@@ -0,0 +1,2 @@
from .imp import import_attr
from .state import load_model_state_dict

map2map/utils/imp.py (new file)
@@ -0,0 +1,41 @@
import os
import importlib.util


def import_attr(name, pkg, callback_at=None):
    """Import attribute. Try package first and then callback directory.

    To use a callback, `name` must contain a module, formatted as 'mod.attr'.

    Examples
    --------
    >>> import_attr('attr', 'pkg1.pkg2')

    tries to import attr from pkg1.pkg2.

    >>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')

    first tries to import attr from pkg1.pkg2.mod, then from
    'path/to/cb_dir/mod.py'.
    """
    if name.count('.') == 0:
        attr = name

        return getattr(importlib.import_module(pkg), attr)
    else:
        mod, attr = name.rsplit('.', 1)

        try:
            return getattr(importlib.import_module(pkg + '.' + mod), attr)
        except (ModuleNotFoundError, AttributeError):
            if callback_at is None:
                raise

            callback_at = os.path.join(callback_at, mod + '.py')
            assert os.path.isfile(callback_at), 'callback file not found'

            spec = importlib.util.spec_from_file_location(mod, callback_at)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            return getattr(module, attr)
@@ -1,5 +1,5 @@
import warnings
import sys
import warnings
from pprint import pformat