diff --git a/map2map/data/fields.py b/map2map/data/fields.py index e023375..6ed0750 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -68,14 +68,14 @@ class FieldDataset(Dataset): if in_norms is not None: assert len(in_patterns) == len(in_norms), \ 'numbers of input normalization functions and fields do not match' - in_norms = [import_attr(norm, norms.__name__, callback_at) + in_norms = [import_attr(norm, norms, callback_at=callback_at) for norm in in_norms] self.in_norms = in_norms if tgt_norms is not None: assert len(tgt_patterns) == len(tgt_norms), \ 'numbers of target normalization functions and fields do not match' - tgt_norms = [import_attr(norm, norms.__name__, callback_at) + tgt_norms = [import_attr(norm, norms, callback_at=callback_at) for norm in tgt_norms] self.tgt_norms = tgt_norms diff --git a/map2map/test.py b/map2map/test.py index 0e26521..841294b 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -40,9 +40,9 @@ def test(args): in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan - model = import_attr(args.model, models.__name__, args.callback_at) + model = import_attr(args.model, models, callback_at=args.callback_at) model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor) - criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at) + criterion = import_attr(args.criterion, torch.nn, callback_at=args.callback_at) criterion = criterion() device = torch.device('cpu') diff --git a/map2map/train.py b/map2map/train.py index e88a0bb..04cd923 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -119,18 +119,19 @@ def gpu_worker(local_rank, node, args): args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan - model = import_attr(args.model, models.__name__, args.callback_at) + model = import_attr(args.model, models, callback_at=args.callback_at) model = model(sum(args.in_chan), sum(args.out_chan), scale_factor=args.scale_factor) model.to(device) model = DistributedDataParallel(model, device_ids=[device], process_group=dist.new_group()) - criterion = import_attr(args.criterion, nn.__name__, args.callback_at) + criterion = import_attr(args.criterion, nn, models, + callback_at=args.callback_at) criterion = criterion() criterion.to(device) - optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at) + optimizer = import_attr(args.optimizer, optim, callback_at=args.callback_at) optimizer = optimizer( model.parameters(), lr=args.lr, diff --git a/map2map/utils/imp.py b/map2map/utils/imp.py index 41efc09..b79a77b 100644 --- a/map2map/utils/imp.py +++ b/map2map/utils/imp.py @@ -2,37 +2,52 @@ import os import importlib -def import_attr(name, pkg, callback_at=None): +def import_attr(name, *pkgs, callback_at=None): """Import attribute. Try package first and then callback directory. To use a callback, `name` must contain a module, formatted as 'mod.attr'. Examples -------- - >>> import_attr('attr', 'pkg1.pkg2') + >>> import_attr('attr', pkg1.pkg2) tries to import attr from pkg1.pkg2. - >>> import_attr('mod.attr', 'pkg1.pkg2', 'path/to/cb_dir') + >>> import_attr('mod.attr', pkg1.pkg2, pkg3, callback_at='path/to/cb_dir') - first tries to import attr from pkg1.pkg2.mod, then from - 'path/to/cb_dir/mod.py'. + first tries to import attr from pkg1.pkg2.mod, then from pkg3.mod, finally + from 'path/to/cb_dir/mod.py'. """ if name.count('.') == 0: attr = name - return getattr(importlib.import_module(pkg), attr) + errors = [] + + for pkg in pkgs: + try: + return getattr(importlib.import_module(pkg.__name__), attr) + except (ModuleNotFoundError, AttributeError) as e: + errors.append(e) + + raise Exception(errors) else: mod, attr = name.rsplit('.', 1) - try: - return getattr(importlib.import_module(pkg + '.' + mod), attr) - except (ModuleNotFoundError, AttributeError): - if callback_at is None: - raise + errors = [] + + for pkg in pkgs: + try: + return getattr( + importlib.import_module(pkg.__name__ + '.' + mod), attr) + except (ModuleNotFoundError, AttributeError): + errors.append(e) + + if callback_at is None: + raise Exception(errors) callback_at = os.path.join(callback_at, mod + '.py') - assert os.path.isfile(callback_at), 'callback file not found' + if not os.path.isfile(callback_at): + raise FileNotFoundError('callback file not found') spec = importlib.util.spec_from_file_location(mod, callback_at) module = importlib.util.module_from_spec(spec)