Fix warnings on incompatible keys
This commit is contained in:
parent
53ed5a91f4
commit
e383ec3977
@ -1,9 +1,15 @@
|
||||
import warnings
|
||||
import sys
|
||||
from pprint import pformat
|
||||
|
||||
|
||||
def load_model_state_dict(model, state_dict, strict=True):
|
||||
bad_keys = model.load_state_dict(state_dict, strict)
|
||||
|
||||
if bad_keys.missing_keys or bad_keys.unexpected_keys:
|
||||
warnings.warn(pformat(bad_keys))
|
||||
if len(bad_keys.missing_keys) > 0:
|
||||
warnings.warn('Missing keys in state_dict:\n{}'.format(
|
||||
pformat(bad_keys.missing_keys)))
|
||||
if len(bad_keys.unexpected_keys) > 0:
|
||||
warnings.warn('Unexpected keys in state_dict:\n{}'.format(
|
||||
pformat(bad_keys.unexpected_keys)))
|
||||
sys.stderr.flush()
|
||||
|
Loading…
Reference in New Issue
Block a user