Add callback loading mechanism

parent 5bb2a19933
commit 90a0d6e0f8

README.md (41 changed lines)
@@ -7,6 +7,7 @@ Neural network emulators to transform field/map data
 * [Data](#data)
 * [Data normalization](#data-normalization)
 * [Model](#model)
+* [Customization](#customization)
 * [Training](#training)
 * [Files generated](#files-generated)
 * [Tracking](#tracking)
@@ -23,33 +24,63 @@ pip install -e .
 
 ## Usage
 
-Take a look at the examples in `scripts/*.slurm`, and the command line options
-in `map2map/args.py` or by `m2m.py -h`.
+The command is `m2m.py` in your `$PATH` after installation.
+Take a look at the examples in `scripts/*.slurm`.
+For all command line options look at `map2map/args.py` or do `m2m.py -h`.
 
 
 ### Data
 
+Put each field in one npy file.
 Structure your data to start with the channel axis and then the spatial
 dimensions.
-Put each sample in one file.
-Specify the data path with glob patterns.
+For example a 2D vector field of size `64^2` should have shape `(2, 64,
+64)`.
+Specify the data path with
+[glob patterns](https://docs.python.org/3/library/glob.html).
+
+During training, pairs of input and target fields are loaded.
+Both input and target data can consist of multiple fields, which are
+then concatenated along the channel axis.
+If the size of a pair of input and target fields is too large to fit in
+a GPU, we can crop part of them to form pairs of samples (see `--crop`).
+Each field can be cropped multiple times, along each dimension,
+controlled by the spacing between two adjacent crops (see `--step`).
+The total sample size is the number of input and target pairs multiplied
+by the number of cropped samples per pair.
 
 
 #### Data normalization
 
 Input and target (output) data can be normalized by functions defined in
 `map2map/data/norms/`.
+Also see [Customization](#customization).
 
 
 ### Model
 
 Find the models in `map2map/models/`.
-Customize the existing models, or add new models there and edit the `__init__.py`.
+Modify the existing models, or write new models somewhere and then
+follow [Customization](#customization).
 
 
 ### Training
 
+
+### Customization
+
+Models, criteria, optimizers and data normalizations can be customized
+without modifying map2map.
+They can be implemented as callbacks in a user directory which is then
+passed by `--callback-at`.
+The default locations are searched first before the callback directory.
+So be aware of name collisions.
+
+This approach is good for experimentation.
+For example, one can play with a model `Bar` in `./foo.py`, by calling
+`m2m.py` with `--model foo.Bar --callback-at .`
+
+
 #### Files generated
 
 * `*.out`: job stdout and stderr
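To make the new README text concrete, here is a hedged sketch of preparing data and a callback file. First, fields as npy files with the channel-first shape described above (paths and values are hypothetical):

```python
import numpy as np

# a 2D vector field of size 64^2: channel axis first, then spatial dimensions
field = np.random.randn(2, 64, 64).astype(np.float32)
np.save('data/train/in-000.npy', field)  # matches the glob 'data/train/in-*.npy'
```

Next, the `./foo.py` callback mentioned in the Customization section. The two-argument constructor follows how map2map instantiates models, `model(sum(in_chan), sum(out_chan))`; the layer choices and the in-place, invertible norm convention are illustrative assumptions, not the library's prescribed API:

```python
# ./foo.py -- hypothetical callback file, used as `--model foo.Bar --callback-at .`
import torch.nn as nn


class Bar(nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        # built by map2map as model(sum(in_chan), sum(out_chan))
        self.net = nn.Sequential(
            nn.Conv2d(in_chan, 32, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, out_chan, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


def lognorm(x, undo=False):
    # toy normalization, selectable as e.g. `--in-norms foo.lognorm`,
    # assuming norms transform the field tensor in place
    if not undo:
        x.log1p_()
    else:
        x.expm1_()
```

A matching invocation might look like `m2m.py train --train-in-patterns "data/train/in-*.npy" --train-tgt-patterns "data/train/tgt-*.npy" --model foo.Bar --callback-at .`; the `train` subcommand and pattern flag names are inferred from `args.train_*` in `train.py` below, so treat them as assumptions.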
map2map/args.py
@@ -1,3 +1,4 @@
+import os
 import argparse
 import warnings
 
@@ -72,6 +73,11 @@ def add_common_args(parser):
             help='maximum pairs of fields in cache, unlimited by default. '
                  'This only applies to training if not None, '
                  'in which case the testing cache maxsize is 1')
+    parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
+            help='directory of custom code defining callbacks for models, '
+                 'norms, criteria, and optimizers. Disabled if not set. '
+                 'This is appended to the default locations, '
+                 'thus has the lowest priority.')
 
 
 def add_train_args(parser):
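A quick check of the new flag's parse-time behavior (the directory name is hypothetical): the `type` lambda normalizes the path immediately, and an unset flag stays `None`, which disables callbacks downstream.

```python
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s))

args = parser.parse_args(['--callback-at', './my_callbacks'])
print(args.callback_at)                   # absolute path, e.g. /home/user/proj/my_callbacks
print(parser.parse_args([]).callback_at)  # None -> callbacks disabled
```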
map2map/data/fields.py
@@ -5,7 +5,8 @@ import torch
 import torch.nn.functional as F
 from torch.utils.data import Dataset
 
-from .norms import import_norm
+from ..utils import import_attr
+from . import norms
 
 
 class FieldDataset(Dataset):
@@ -38,7 +39,7 @@ class FieldDataset(Dataset):
     This saves CPU RAM but limits stochasticity.
     """
     def __init__(self, in_patterns, tgt_patterns,
-                 in_norms=None, tgt_norms=None,
+                 in_norms=None, tgt_norms=None, callback_at=None,
                  augment=False, aug_add=None, aug_mul=None,
                  crop=None, pad=0, scale_factor=1,
                  cache=False, cache_maxsize=None, div_data=False,
@@ -65,13 +66,15 @@ class FieldDataset(Dataset):
         if in_norms is not None:
             assert len(in_patterns) == len(in_norms), \
                 'numbers of input normalization functions and fields do not match'
-            in_norms = [import_norm(norm) for norm in in_norms]
+            in_norms = [import_attr(norm, norms.__name__, callback_at)
+                        for norm in in_norms]
         self.in_norms = in_norms
 
         if tgt_norms is not None:
             assert len(tgt_patterns) == len(tgt_norms), \
                 'numbers of target normalization functions and fields do not match'
-            tgt_norms = [import_norm(norm) for norm in tgt_norms]
+            tgt_norms = [import_attr(norm, norms.__name__, callback_at)
+                         for norm in tgt_norms]
         self.tgt_norms = tgt_norms
 
         self.augment = augment
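A hedged sketch of constructing the dataset with the new keyword (pattern strings are placeholders; `cosmology.dis` assumes such a function exists in the `cosmology` module under `map2map/data/norms/`, and `foo.lognorm` would resolve through the callback directory):

```python
from map2map.data import FieldDataset  # exported, per the import in test.py

dataset = FieldDataset(
    in_patterns=['data/train/in-*.npy'],   # hypothetical paths
    tgt_patterns=['data/train/tgt-*.npy'],
    in_norms=['cosmology.dis'],            # looked up under map2map.data.norms
    tgt_norms=['foo.lognorm'],             # falls back to <callback_at>/foo.py
    callback_at='.',
)
```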
map2map/data/norms/__init__.py
@@ -1,13 +0,0 @@
-from importlib import import_module
-
-from . import cosmology
-
-
-def import_norm(norm):
-    if callable(norm):
-        return norm
-
-    mod, fun = norm.rsplit('.', 1)
-    mod = import_module('.' + mod, __name__)
-    fun = getattr(mod, fun)
-    return fun
map2map/test.py
@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
 from .data import FieldDataset
 from . import models
 from .models import narrow_like
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict
 
 
 def test(args):
@@ -18,6 +18,7 @@ def test(args):
         tgt_patterns=args.test_tgt_patterns,
         in_norms=args.in_norms,
         tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
         augment=False,
         aug_add=None,
         aug_mul=None,
@@ -35,9 +36,9 @@ def test(args):
 
     in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
 
-    model = getattr(models, args.model)
+    model = import_attr(args.model, models.__name__, args.callback_at)
     model = model(sum(in_chan), sum(out_chan))
-    criterion = getattr(torch.nn, args.criterion)
+    criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
     criterion = criterion()
 
     device = torch.device('cpu')
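The `getattr` to `import_attr` substitution is behavior-preserving for plain names and only extends dotted ones; a small illustration, written against the `import_attr` semantics documented in `map2map/utils/imp.py` below:

```python
import torch.nn
from map2map.utils import import_attr

# plain name: equivalent to the old getattr(torch.nn, 'MSELoss')
criterion = import_attr('MSELoss', torch.nn.__name__, None)()

# a dotted name like 'foo.BarLoss' would first try torch.nn.foo.BarLoss,
# then fall back to a class BarLoss defined in <callback_at>/foo.py
```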
map2map/train.py
@@ -6,6 +6,7 @@ from pprint import pprint
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.optim as optim
 import torch.distributed as dist
 from torch.multiprocessing import spawn
 from torch.nn.parallel import DistributedDataParallel
@@ -20,7 +21,7 @@ from .models import (narrow_like,
                      adv_model_wrapper, adv_criterion_wrapper,
                      add_spectral_norm, rm_spectral_norm,
                      InstanceNoise)
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict
 
 
 ckpt_link = 'checkpoint.pth'
@@ -62,6 +63,7 @@ def gpu_worker(local_rank, node, args):
         tgt_patterns=args.train_tgt_patterns,
         in_norms=args.in_norms,
         tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
         augment=args.augment,
         aug_add=args.aug_add,
         aug_mul=args.aug_mul,
@@ -100,6 +102,7 @@ def gpu_worker(local_rank, node, args):
         tgt_patterns=args.val_tgt_patterns,
         in_norms=args.in_norms,
         tgt_norms=args.tgt_norms,
+        callback_at=args.callback_at,
         augment=False,
         aug_add=None,
         aug_mul=None,
@@ -130,17 +133,17 @@ def gpu_worker(local_rank, node, args):
 
     args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
-    model = getattr(models, args.model)
+    model = import_attr(args.model, models.__name__, args.callback_at)
     model = model(sum(args.in_chan), sum(args.out_chan))
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
 
-    criterion = getattr(nn, args.criterion)
+    criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
     criterion = criterion()
     criterion.to(device)
 
-    optimizer = getattr(torch.optim, args.optimizer)
+    optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
     optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
@@ -148,12 +151,12 @@ def gpu_worker(local_rank, node, args):
         betas=(0.5, 0.999),
         weight_decay=args.weight_decay,
     )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
             factor=0.1, patience=10, verbose=True)
 
     adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
     if args.adv:
-        adv_model = getattr(models, args.adv_model)
+        adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
         adv_model = adv_model_wrapper(adv_model)
         adv_model = adv_model(sum(args.in_chan + args.out_chan)
                               if args.cgan else sum(args.out_chan), 1)
@@ -163,19 +166,19 @@ def gpu_worker(local_rank, node, args):
         adv_model = DistributedDataParallel(adv_model, device_ids=[device],
                 process_group=dist.new_group())
 
-        adv_criterion = getattr(nn, args.adv_criterion)
+        adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
         adv_criterion = adv_criterion_wrapper(adv_criterion)
         adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
         adv_criterion.to(device)
 
-        adv_optimizer = getattr(torch.optim, args.optimizer)
+        adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
         adv_optimizer = adv_optimizer(
             adv_model.parameters(),
             lr=args.adv_lr,
             betas=(0.5, 0.999),
             weight_decay=args.adv_weight_decay,
         )
-        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
+        adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
                 factor=0.1, patience=10, verbose=True)
 
     if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
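Because the optimizer is now resolved the same way, a custom optimizer can ride the same mechanism, e.g. `--optimizer foo.AdamClip --callback-at .`; a hypothetical sketch (map2map instantiates it with `lr`, `betas` and `weight_decay`, per the hunk above, so subclassing Adam keeps the signature compatible):

```python
# ./foo.py -- hypothetical optimizer callback
import torch


class AdamClip(torch.optim.Adam):
    """Adam that clips gradient norms before each step."""

    def step(self, closure=None):
        for group in self.param_groups:
            params = [p for p in group['params'] if p.grad is not None]
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
        return super().step(closure)
```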
map2map/utils/__init__.py (new file)
@@ -0,0 +1,2 @@
+from .imp import import_attr
+from .state import load_model_state_dict
map2map/utils/imp.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+import os
+import importlib.util
+
+
+def import_attr(name, pkg, callback_at=None):
+    """Import attribute. Try package first and then callback directory.
+
+    To use a callback, `name` must contain a module, formatted as 'mod.attr'.
+
+    Examples
+    --------
+    >>> import_attr('attr', 'pkg1.pkg2')
+
+    tries to import attr from pkg1.pkg2.
+
+    >>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')
+
+    first tries to import attr from pkg1.pkg2.mod, then from
+    'path/to/cb_dir/mod.py'.
+    """
+    if name.count('.') == 0:
+        attr = name
+
+        return getattr(importlib.import_module(pkg), attr)
+    else:
+        mod, attr = name.rsplit('.', 1)
+
+        try:
+            return getattr(importlib.import_module(pkg + '.' + mod), attr)
+        except (ModuleNotFoundError, AttributeError):
+            if callback_at is None:
+                raise
+
+            callback_at = os.path.join(callback_at, mod + '.py')
+            assert os.path.isfile(callback_at), 'callback file not found'
+
+            spec = importlib.util.spec_from_file_location(mod, callback_at)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+
+            return getattr(module, attr)
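A quick walk-through of the fallback branch above, assuming a hypothetical callback file `cb_dir/mynorm.py` that defines a function `scale`:

```python
from map2map.utils import import_attr

# 'mynorm.scale': map2map.data.norms.mynorm does not exist, so the
# ModuleNotFoundError is caught and cb_dir/mynorm.py is loaded instead
scale = import_attr('mynorm.scale', 'map2map.data.norms', callback_at='cb_dir')
```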
map2map/utils/state.py (moved from map2map/state.py)
@@ -1,5 +1,5 @@
-import warnings
 import sys
+import warnings
 from pprint import pformat
 
 
setup.py (2 changed lines)
@@ -8,7 +8,7 @@ setup(
     author='Yin Li et al.',
     author_email='eelregit@gmail.com',
     packages=find_packages(),
-    python_requires='>=3.2',
+    python_requires='>=3.6',
     install_requires=[
         'torch',
         'numpy',