Fix error when pickling normalization functions in multiprocessing data loading
By avoiding importing functions as class attributes and instead importing them right before calling, and by using a cache in the importing function. Thanks to Yu Feng for the help.
This commit is contained in:
parent
0e0ad8f071
commit
6c62cb09db
@ -68,17 +68,15 @@ class FieldDataset(Dataset):
|
|||||||
if in_norms is not None:
|
if in_norms is not None:
|
||||||
assert len(in_patterns) == len(in_norms), \
|
assert len(in_patterns) == len(in_norms), \
|
||||||
'numbers of input normalization functions and fields do not match'
|
'numbers of input normalization functions and fields do not match'
|
||||||
in_norms = [import_attr(norm, norms, callback_at=callback_at)
|
|
||||||
for norm in in_norms]
|
|
||||||
self.in_norms = in_norms
|
self.in_norms = in_norms
|
||||||
|
|
||||||
if tgt_norms is not None:
|
if tgt_norms is not None:
|
||||||
assert len(tgt_patterns) == len(tgt_norms), \
|
assert len(tgt_patterns) == len(tgt_norms), \
|
||||||
'numbers of target normalization functions and fields do not match'
|
'numbers of target normalization functions and fields do not match'
|
||||||
tgt_norms = [import_attr(norm, norms, callback_at=callback_at)
|
|
||||||
for norm in tgt_norms]
|
|
||||||
self.tgt_norms = tgt_norms
|
self.tgt_norms = tgt_norms
|
||||||
|
|
||||||
|
self.callback_at = callback_at
|
||||||
|
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
if self.ndim == 1 and self.augment:
|
if self.ndim == 1 and self.augment:
|
||||||
raise ValueError('cannot augment 1D fields')
|
raise ValueError('cannot augment 1D fields')
|
||||||
@ -151,9 +149,11 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
if self.in_norms is not None:
|
if self.in_norms is not None:
|
||||||
for norm, x in zip(self.in_norms, in_fields):
|
for norm, x in zip(self.in_norms, in_fields):
|
||||||
|
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||||
norm(x)
|
norm(x)
|
||||||
if self.tgt_norms is not None:
|
if self.tgt_norms is not None:
|
||||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||||
|
norm = import_attr(norm, norms, callback_at=self.callback_at)
|
||||||
norm(x)
|
norm(x)
|
||||||
|
|
||||||
if self.augment:
|
if self.augment:
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import importlib
|
import importlib
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
def import_attr(name, *pkgs, callback_at=None):
|
def import_attr(name, *pkgs, callback_at=None):
|
||||||
"""Import attribute. Try package first and then callback directory.
|
"""Import attribute. Try package first and then callback directory.
|
||||||
|
|
||||||
@ -50,6 +52,9 @@ def import_attr(name, *pkgs, callback_at=None):
|
|||||||
if not os.path.isfile(callback_at):
|
if not os.path.isfile(callback_at):
|
||||||
raise FileNotFoundError('callback file not found')
|
raise FileNotFoundError('callback file not found')
|
||||||
|
|
||||||
|
if mod in sys.modules:
|
||||||
|
return getattr(sys.modules[mod], attr)
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(mod, callback_at)
|
spec = importlib.util.spec_from_file_location(mod, callback_at)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
sys.modules[mod] = module
|
sys.modules[mod] = module
|
||||||
|
Loading…
Reference in New Issue
Block a user