Add multiple pkgs to import_attr

This commit is contained in:
Yin Li 2020-08-20 21:32:29 -04:00
parent 79b28561d5
commit f3216f4cbd
4 changed files with 35 additions and 19 deletions

View File

@ -68,14 +68,14 @@ class FieldDataset(Dataset):
if in_norms is not None: if in_norms is not None:
assert len(in_patterns) == len(in_norms), \ assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match' 'numbers of input normalization functions and fields do not match'
in_norms = [import_attr(norm, norms.__name__, callback_at) in_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in in_norms] for norm in in_norms]
self.in_norms = in_norms self.in_norms = in_norms
if tgt_norms is not None: if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \ assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match' 'numbers of target normalization functions and fields do not match'
tgt_norms = [import_attr(norm, norms.__name__, callback_at) tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in tgt_norms] for norm in tgt_norms]
self.tgt_norms = tgt_norms self.tgt_norms = tgt_norms

View File

@ -40,9 +40,9 @@ def test(args):
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
model = import_attr(args.model, models.__name__, args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at) criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
criterion = criterion() criterion = criterion()
device = torch.device('cpu') device = torch.device('cpu')

View File

@ -119,18 +119,19 @@ def gpu_worker(local_rank, node, args):
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
model = import_attr(args.model, models.__name__, args.callback_at) model = import_attr(args.model, models, callback_at=args.callback_at)
model = model(sum(args.in_chan), sum(args.out_chan), model = model(sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor) scale_factor=args.scale_factor)
model.to(device) model.to(device)
model = DistributedDataParallel(model, device_ids=[device], model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group()) process_group=dist.new_group())
criterion = import_attr(args.criterion, nn.__name__, args.callback_at) criterion = import_attr(args.criterion, nn, models,
callback_at=args.callback_at)
criterion = criterion() criterion = criterion()
criterion.to(device) criterion.to(device)
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at) optimizer = import_attr(args.optimizer, optim, callback_at=args.callback_at)
optimizer = optimizer( optimizer = optimizer(
model.parameters(), model.parameters(),
lr=args.lr, lr=args.lr,

View File

@ -2,37 +2,52 @@ import os
import importlib import importlib
def import_attr(name, pkg, callback_at=None): def import_attr(name, *pkgs, callback_at=None):
"""Import attribute. Try package first and then callback directory. """Import attribute. Try package first and then callback directory.
To use a callback, `name` must contain a module, formatted as 'mod.attr'. To use a callback, `name` must contain a module, formatted as 'mod.attr'.
Examples Examples
-------- --------
>>> import_attr('attr', 'pkg1.pkg2') >>> import_attr('attr', pkg1.pkg2)
tries to import attr from pkg1.pkg2. tries to import attr from pkg1.pkg2.
>>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir') >>> import_attr('mod.attr', pkg1.pkg2, pkg3, callback_at='path/to/cb_dir')
first tries to import attr from pkg1.pkg2.mod, then from first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
'path/to/cb_dir/mod.py'. from 'path/to/cb_dir/mod.py'.
""" """
if name.count('.') == 0: if name.count('.') == 0:
attr = name attr = name
return getattr(importlib.import_module(pkg), attr) errors = []
for pkg in pkgs:
try:
return getattr(importlib.import_module(pkg.__name__), attr)
except (ModuleNotFoundError, AttributeError) as e:
errors.append(e)
raise Exception(errors)
else: else:
mod, attr = name.rsplit('.', 1) mod, attr = name.rsplit('.', 1)
try: errors = []
return getattr(importlib.import_module(pkg + '.' + mod), attr)
except (ModuleNotFoundError, AttributeError): for pkg in pkgs:
if callback_at is None: try:
raise return getattr(
importlib.import_module(pkg.__name__ + '.' + mod), attr)
except (ModuleNotFoundError, AttributeError):
errors.append(e)
if callback_at is None:
raise Exception(errors)
callback_at = os.path.join(callback_at, mod + '.py') callback_at = os.path.join(callback_at, mod + '.py')
assert os.path.isfile(callback_at), 'callback file not found' if not os.path.isfile(callback_at):
raise FileNotFoundError('callback file not found')
spec = importlib.util.spec_from_file_location(mod, callback_at) spec = importlib.util.spec_from_file_location(mod, callback_at)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)