Fix warnings on incompatible keys

This commit is contained in:
Yin Li 2020-02-13 14:39:37 -05:00
parent 53ed5a91f4
commit e383ec3977

View File

@ -1,9 +1,15 @@
import warnings import warnings
import sys
from pprint import pformat from pprint import pformat
def load_model_state_dict(model, state_dict, strict=True): def load_model_state_dict(model, state_dict, strict=True):
bad_keys = model.load_state_dict(state_dict, strict) bad_keys = model.load_state_dict(state_dict, strict)
if bad_keys.missing_keys or bad_keys.unexpected_keys: if len(bad_keys.missing_keys) > 0:
warnings.warn(pformat(bad_keys)) warnings.warn('Missing keys in state_dict:\n{}'.format(
pformat(bad_keys.missing_keys)))
if len(bad_keys.unexpected_keys) > 0:
warnings.warn('Unexpected keys in state_dict:\n{}'.format(
pformat(bad_keys.unexpected_keys)))
sys.stderr.flush()