Add callback loading mechanism

Yin Li 2020-06-14 17:59:31 -04:00
parent 5bb2a19933
commit 90a0d6e0f8
10 changed files with 110 additions and 36 deletions

View file

@@ -1,3 +1,4 @@
+import os
import argparse
import warnings
@@ -72,6 +73,11 @@ def add_common_args(parser):
help='maximum pairs of fields in cache, unlimited by default. '
'This only applies to training if not None, '
'in which case the testing cache maxsize is 1')
+parser.add_argument('--callback-at', type=lambda s: os.path.abspath(s),
+help='directory of custom code defining callbacks for models, '
+'norms, criteria, and optimizers. Disabled if not set. '
+'This is appended to the default locations, '
+'and thus has the lowest priority.')
def add_train_args(parser):
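As a hedged illustration of this flag (everything below is hypothetical, not part of this commit): with --callback-at ./callbacks, a file ./callbacks/mynorms.py could define a custom norm. The built-in norms are assumed here to be callables that transform the field tensor in place, with an undo flag for the inverse:

import torch

def log1p(x, undo=False):
    # assumed norm convention: modify the tensor in place;
    # undo=True applies the inverse transform
    if not undo:
        torch.log1p(x, out=x)
    else:
        torch.expm1(x, out=x)

It would then be selected with --in-norms mynorms.log1p (built-in norms still win, since the callback directory is tried last).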

View file

@@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
-from .norms import import_norm
+from ..utils import import_attr
+from . import norms
class FieldDataset(Dataset):
@@ -38,7 +39,7 @@ class FieldDataset(Dataset):
This saves CPU RAM but limits stochasticity.
"""
def __init__(self, in_patterns, tgt_patterns,
-in_norms=None, tgt_norms=None,
+in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None,
crop=None, pad=0, scale_factor=1,
cache=False, cache_maxsize=None, div_data=False,
@@ -65,13 +66,15 @@
if in_norms is not None:
assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match'
-in_norms = [import_norm(norm) for norm in in_norms]
+in_norms = [import_attr(norm, norms.__name__, callback_at)
+for norm in in_norms]
self.in_norms = in_norms
if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match'
-tgt_norms = [import_norm(norm) for norm in tgt_norms]
+tgt_norms = [import_attr(norm, norms.__name__, callback_at)
+for norm in tgt_norms]
self.tgt_norms = tgt_norms
self.augment = augment

View file

@@ -1,13 +0,0 @@
-from importlib import import_module
-from . import cosmology
-def import_norm(norm):
-if callable(norm):
-return norm
-mod, fun = norm.rsplit('.', 1)
-mod = import_module('.' + mod, __name__)
-fun = getattr(mod, fun)
-return fun

View file

@@ -6,7 +6,7 @@ from torch.utils.data import DataLoader
from .data import FieldDataset
from . import models
from .models import narrow_like
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict
def test(args):
@@ -18,6 +18,7 @@ def test(args):
tgt_patterns=args.test_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
+callback_at=args.callback_at,
augment=False,
aug_add=None,
aug_mul=None,
@@ -35,9 +36,9 @@
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
-model = getattr(models, args.model)
+model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(in_chan), sum(out_chan))
-criterion = getattr(torch.nn, args.criterion)
+criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
criterion = criterion()
device = torch.device('cpu')
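The model lookup gains the same fallback: args.model may now name a class in the callback directory. A minimal sketch, assuming a hypothetical file ./callbacks/mynet.py (the class name and layer choice are illustrative); the class must be constructible as model(sum(in_chan), sum(out_chan)), matching the call above:

import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, in_chan, out_chan):
        super().__init__()
        # the caller passes summed input/output channel counts
        self.conv = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

Running with --model mynet.MyNet --callback-at ./callbacks would then resolve MyNet via import_attr.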

View file

@@ -6,6 +6,7 @@ from pprint import pprint
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch.optim as optim
import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel
@@ -20,7 +21,7 @@ from .models import (narrow_like,
adv_model_wrapper, adv_criterion_wrapper,
add_spectral_norm, rm_spectral_norm,
InstanceNoise)
-from .state import load_model_state_dict
+from .utils import import_attr, load_model_state_dict
ckpt_link = 'checkpoint.pth'
@@ -62,6 +63,7 @@ def gpu_worker(local_rank, node, args):
tgt_patterns=args.train_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
+callback_at=args.callback_at,
augment=args.augment,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
@@ -100,6 +102,7 @@ def gpu_worker(local_rank, node, args):
tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
+callback_at=args.callback_at,
augment=False,
aug_add=None,
aug_mul=None,
@@ -130,17 +133,17 @@ def gpu_worker(local_rank, node, args):
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
-model = getattr(models, args.model)
+model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(args.in_chan), sum(args.out_chan))
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())
-criterion = getattr(nn, args.criterion)
+criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
criterion = criterion()
criterion.to(device)
-optimizer = getattr(torch.optim, args.optimizer)
+optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
optimizer = optimizer(
model.parameters(),
lr=args.lr,
@@ -148,12 +151,12 @@ def gpu_worker(local_rank, node, args):
betas=(0.5, 0.999),
weight_decay=args.weight_decay,
)
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.1, patience=10, verbose=True)
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
if args.adv:
-adv_model = getattr(models, args.adv_model)
+adv_model = import_attr(args.adv_model, models.__name__, args.callback_at)
adv_model = adv_model_wrapper(adv_model)
adv_model = adv_model(sum(args.in_chan + args.out_chan)
if args.cgan else sum(args.out_chan), 1)
@@ -163,19 +166,19 @@ def gpu_worker(local_rank, node, args):
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group())
-adv_criterion = getattr(nn, args.adv_criterion)
+adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
adv_criterion.to(device)
-adv_optimizer = getattr(torch.optim, args.optimizer)
+adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
adv_optimizer = adv_optimizer(
adv_model.parameters(),
lr=args.adv_lr,
betas=(0.5, 0.999),
weight_decay=args.adv_weight_decay,
)
-adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
+adv_scheduler = optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
factor=0.1, patience=10, verbose=True)
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
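Criteria and optimizers resolve the same way. A hedged sketch of a callback optimizer (the file ./callbacks/myopt.py and the AdamW subclass are hypothetical): the training loop passes lr, betas, and weight_decay keyword arguments, so a callback class must accept them:

import torch

class MyAdamW(torch.optim.AdamW):
    # a thin subclass is enough; the signature mirrors the keyword
    # arguments the trainer passes when instantiating the optimizer
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr=lr, betas=betas, weight_decay=weight_decay)

Selected with --optimizer myopt.MyAdamW, it would also be used for the adversary, since adv_optimizer reuses args.optimizer above.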

View file

@@ -0,0 +1,2 @@
+from .imp import import_attr
+from .state import load_model_state_dict

map2map/utils/imp.py Normal file (41 lines)
View file

@@ -0,0 +1,41 @@
+import os
+import importlib.util
+
+
+def import_attr(name, pkg, callback_at=None):
+    """Import attribute. Try package first and then callback directory.
+
+    To use a callback, `name` must contain a module, formatted as 'mod.attr'.
+
+    Examples
+    --------
+    >>> import_attr('attr', 'pkg1.pkg2')
+
+    tries to import attr from pkg1.pkg2.
+
+    >>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')
+
+    first tries to import attr from pkg1.pkg2.mod, then from
+    'path/to/cb_dir/mod.py'.
+    """
+    if name.count('.') == 0:
+        attr = name
+
+        return getattr(importlib.import_module(pkg), attr)
+    else:
+        mod, attr = name.rsplit('.', 1)
+
+        try:
+            return getattr(importlib.import_module(pkg + '.' + mod), attr)
+        except (ModuleNotFoundError, AttributeError):
+            if callback_at is None:
+                raise
+
+            callback_at = os.path.join(callback_at, mod + '.py')
+
+            assert os.path.isfile(callback_at), 'callback file not found'
+
+            spec = importlib.util.spec_from_file_location(mod, callback_at)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            return getattr(module, attr)
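A short usage sketch of the resolution order (the directory ./callbacks is hypothetical; cosmology is the built-in norm module seen in the deleted norms/__init__.py above, and dis is assumed to be one of its attributes):

from map2map.data import norms
from map2map.utils import import_attr

# 'cosmology.dis' contains a dot, so map2map.data.norms.cosmology is
# tried first; only if that import or the attribute lookup fails is
# ./callbacks/cosmology.py loaded from the callback directory
norm = import_attr('cosmology.dis', norms.__name__, callback_at='./callbacks')

This ordering is why the --callback-at help text says the directory has the lowest priority.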

View file

@@ -1,5 +1,5 @@
-import warnings
import sys
+import warnings
from pprint import pformat