Fix error pickling normalization functions in multiprocessing data loading

Avoid storing the imported functions as class attributes; instead, import
them right before calling, and cache the results in the importing function

Thanks to Yu Feng for the help
This commit is contained in:
Yin Li 2020-09-09 14:40:11 -04:00
parent 0e0ad8f071
commit 6c62cb09db
2 changed files with 9 additions and 4 deletions

View File

@ -68,17 +68,15 @@ class FieldDataset(Dataset):
if in_norms is not None: if in_norms is not None:
assert len(in_patterns) == len(in_norms), \ assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match' 'numbers of input normalization functions and fields do not match'
in_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in in_norms]
self.in_norms = in_norms self.in_norms = in_norms
if tgt_norms is not None: if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \ assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match' 'numbers of target normalization functions and fields do not match'
tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
for norm in tgt_norms]
self.tgt_norms = tgt_norms self.tgt_norms = tgt_norms
self.callback_at = callback_at
self.augment = augment self.augment = augment
if self.ndim == 1 and self.augment: if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields') raise ValueError('cannot augment 1D fields')
@ -151,9 +149,11 @@ class FieldDataset(Dataset):
if self.in_norms is not None: if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields): for norm, x in zip(self.in_norms, in_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x) norm(x)
if self.tgt_norms is not None: if self.tgt_norms is not None:
for norm, x in zip(self.tgt_norms, tgt_fields): for norm, x in zip(self.tgt_norms, tgt_fields):
norm = import_attr(norm, norms, callback_at=self.callback_at)
norm(x) norm(x)
if self.augment: if self.augment:

View File

@ -1,8 +1,10 @@
import os import os
import sys import sys
import importlib import importlib
from functools import lru_cache
@lru_cache(maxsize=None)
def import_attr(name, *pkgs, callback_at=None): def import_attr(name, *pkgs, callback_at=None):
"""Import attribute. Try package first and then callback directory. """Import attribute. Try package first and then callback directory.
@ -50,6 +52,9 @@ def import_attr(name, *pkgs, callback_at=None):
if not os.path.isfile(callback_at): if not os.path.isfile(callback_at):
raise FileNotFoundError('callback file not found') raise FileNotFoundError('callback file not found')
if mod in sys.modules:
return getattr(sys.modules[mod], attr)
spec = importlib.util.spec_from_file_location(mod, callback_at) spec = importlib.util.spec_from_file_location(mod, callback_at)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
sys.modules[mod] = module sys.modules[mod] = module