Add multiple pkgs to import_attr
This commit is contained in:
parent
79b28561d5
commit
f3216f4cbd
@ -68,14 +68,14 @@ class FieldDataset(Dataset):
|
|||||||
if in_norms is not None:
|
if in_norms is not None:
|
||||||
assert len(in_patterns) == len(in_norms), \
|
assert len(in_patterns) == len(in_norms), \
|
||||||
'numbers of input normalization functions and fields do not match'
|
'numbers of input normalization functions and fields do not match'
|
||||||
in_norms = [import_attr(norm, norms.__name__, callback_at)
|
in_norms = [import_attr(norm, norms, callback_at=callback_at)
|
||||||
for norm in in_norms]
|
for norm in in_norms]
|
||||||
self.in_norms = in_norms
|
self.in_norms = in_norms
|
||||||
|
|
||||||
if tgt_norms is not None:
|
if tgt_norms is not None:
|
||||||
assert len(tgt_patterns) == len(tgt_norms), \
|
assert len(tgt_patterns) == len(tgt_norms), \
|
||||||
'numbers of target normalization functions and fields do not match'
|
'numbers of target normalization functions and fields do not match'
|
||||||
tgt_norms = [import_attr(norm, norms.__name__, callback_at)
|
tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
|
||||||
for norm in tgt_norms]
|
for norm in tgt_norms]
|
||||||
self.tgt_norms = tgt_norms
|
self.tgt_norms = tgt_norms
|
||||||
|
|
||||||
|
@ -40,9 +40,9 @@ def test(args):
|
|||||||
|
|
||||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||||
|
|
||||||
model = import_attr(args.model, models.__name__, args.callback_at)
|
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||||
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||||
criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
|
criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
|
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
|
@ -119,18 +119,19 @@ def gpu_worker(local_rank, node, args):
|
|||||||
|
|
||||||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||||
|
|
||||||
model = import_attr(args.model, models.__name__, args.callback_at)
|
model = import_attr(args.model, models, callback_at=args.callback_at)
|
||||||
model = model(sum(args.in_chan), sum(args.out_chan),
|
model = model(sum(args.in_chan), sum(args.out_chan),
|
||||||
scale_factor=args.scale_factor)
|
scale_factor=args.scale_factor)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
criterion = import_attr(args.criterion, nn.__name__, args.callback_at)
|
criterion = import_attr(args.criterion, nn, models,
|
||||||
|
callback_at=args.callback_at)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(device)
|
criterion.to(device)
|
||||||
|
|
||||||
optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
optimizer = import_attr(args.optimizer, optim, callback_at=args.callback_at)
|
||||||
optimizer = optimizer(
|
optimizer = optimizer(
|
||||||
model.parameters(),
|
model.parameters(),
|
||||||
lr=args.lr,
|
lr=args.lr,
|
||||||
|
@ -2,37 +2,52 @@ import os
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
def import_attr(name, pkg, callback_at=None):
|
def import_attr(name, *pkgs, callback_at=None):
|
||||||
"""Import attribute. Try package first and then callback directory.
|
"""Import attribute. Try package first and then callback directory.
|
||||||
|
|
||||||
To use a callback, `name` must contain a module, formatted as 'mod.attr'.
|
To use a callback, `name` must contain a module, formatted as 'mod.attr'.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> import_attr('attr', 'pkg1.pkg2')
|
>>> import_attr('attr', pkg1.pkg2)
|
||||||
|
|
||||||
tries to import attr from pkg1.pkg2.
|
tries to import attr from pkg1.pkg2.
|
||||||
|
|
||||||
>>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir')
|
>>> import_attr('mod.attr', pkg1.pkg2, pkg3, callback_at='path/to/cb_dir')
|
||||||
|
|
||||||
first tries to import attr from pkg1.pkg2.mod, then from
|
first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally
|
||||||
'path/to/cb_dir/mod.py'.
|
from 'path/to/cb_dir/mod.py'.
|
||||||
"""
|
"""
|
||||||
if name.count('.') == 0:
|
if name.count('.') == 0:
|
||||||
attr = name
|
attr = name
|
||||||
|
|
||||||
return getattr(importlib.import_module(pkg), attr)
|
errors = []
|
||||||
|
|
||||||
|
for pkg in pkgs:
|
||||||
|
try:
|
||||||
|
return getattr(importlib.import_module(pkg.__name__), attr)
|
||||||
|
except (ModuleNotFoundError, AttributeError) as e:
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
|
raise Exception(errors)
|
||||||
else:
|
else:
|
||||||
mod, attr = name.rsplit('.', 1)
|
mod, attr = name.rsplit('.', 1)
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
for pkg in pkgs:
|
||||||
try:
|
try:
|
||||||
return getattr(importlib.import_module(pkg + '.' + mod), attr)
|
return getattr(
|
||||||
|
importlib.import_module(pkg.__name__ + '.' + mod), attr)
|
||||||
except (ModuleNotFoundError, AttributeError):
|
except (ModuleNotFoundError, AttributeError):
|
||||||
|
errors.append(e)
|
||||||
|
|
||||||
if callback_at is None:
|
if callback_at is None:
|
||||||
raise
|
raise Exception(errors)
|
||||||
|
|
||||||
callback_at = os.path.join(callback_at, mod + '.py')
|
callback_at = os.path.join(callback_at, mod + '.py')
|
||||||
assert os.path.isfile(callback_at), 'callback file not found'
|
if not os.path.isfile(callback_at):
|
||||||
|
raise FileNotFoundError('callback file not found')
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(mod, callback_at)
|
spec = importlib.util.spec_from_file_location(mod, callback_at)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
Loading…
Reference in New Issue
Block a user