Fix error pickling normalization function in multiprocessing dataloading
By avoiding importing functions as class attributes but importing before calling, and using cache in the importing function Thanks to help by Yu Feng
This commit is contained in:
parent
0e0ad8f071
commit
6c62cb09db
2 changed files with 9 additions and 4 deletions
|
@ -68,17 +68,15 @@ class FieldDataset(Dataset):
|
|||
if in_norms is not None:
|
||||
assert len(in_patterns) == len(in_norms), \
|
||||
'numbers of input normalization functions and fields do not match'
|
||||
in_norms = [import_attr(norm, norms, callback_at=callback_at)
|
||||
for norm in in_norms]
|
||||
self.in_norms = in_norms
|
||||
|
||||
if tgt_norms is not None:
|
||||
assert len(tgt_patterns) == len(tgt_norms), \
|
||||
'numbers of target normalization functions and fields do not match'
|
||||
tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
|
||||
for norm in tgt_norms]
|
||||
self.tgt_norms = tgt_norms
|
||||
|
||||
self.callback_at = callback_at
|
||||
|
||||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
|
@ -151,9 +149,11 @@ class FieldDataset(Dataset):
|
|||
|
||||
if self.in_norms is not None:
|
||||
for norm, x in zip(self.in_norms, in_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x)
|
||||
if self.tgt_norms is not None:
|
||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||
norm(x)
|
||||
|
||||
if self.augment:
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import os
|
||||
import sys
|
||||
import importlib
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def import_attr(name, *pkgs, callback_at=None):
|
||||
"""Import attribute. Try package first and then callback directory.
|
||||
|
||||
|
@ -50,6 +52,9 @@ def import_attr(name, *pkgs, callback_at=None):
|
|||
if not os.path.isfile(callback_at):
|
||||
raise FileNotFoundError('callback file not found')
|
||||
|
||||
if mod in sys.modules:
|
||||
return getattr(sys.modules[mod], attr)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(mod, callback_at)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[mod] = module
|
||||
|
|
Loading…
Reference in a new issue