Fix error when pickling normalization functions in multiprocessing data loading

This is done by no longer storing the imported functions as attributes of the
dataset, importing them right before they are called instead, and by caching
the results of the import function.

Thanks to Yu Feng for the help.
This commit is contained in:
Yin Li 2020-09-09 14:40:11 -04:00
parent 0e0ad8f071
commit 6c62cb09db
2 changed files with 9 additions and 4 deletions

View file

@ -68,17 +68,15 @@ class FieldDataset(Dataset):
if in_norms is not None:
assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match'
in_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in in_norms]
self.in_norms = in_norms
if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match'
tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in tgt_norms]
self.tgt_norms = tgt_norms
self.callback_at = callback_at
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
@ -151,9 +149,11 @@ class FieldDataset(Dataset):
if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x)
if self.tgt_norms is not None:
for norm, x in zip(self.tgt_norms, tgt_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x)
if self.augment:

View file

@ -1,8 +1,10 @@
import os
import sys
import importlib
from functools import lru_cache
@lru_cache(maxsize=None)
def import_attr(name, *pkgs, callback_at=None):
"""Import attribute. Try package first and then callback directory.
@ -50,6 +52,9 @@ def import_attr(name, *pkgs, callback_at=None):
if not os.path.isfile(callback_at):
raise FileNotFoundError('callback file not found')
if mod in sys.modules:
return getattr(sys.modules[mod], attr)
spec = importlib.util.spec_from_file_location(mod, callback_at)
module = importlib.util.module_from_spec(spec)
sys.modules[mod] = module