Remove model loading hack in testing

This commit is contained in:
Yin Li 2020-01-20 21:32:02 -05:00
parent 2b7e559910
commit 6e48905cc0

View File

@ -32,12 +32,6 @@ def test(args):
device = torch.device('cpu')
state = torch.load(args.load_state, map_location=device)
# from collections import OrderedDict
# model_state = OrderedDict()
# for k, v in state['model'].items():
# model_k = k.replace('module.', '', 1) # FIXME
# model_state[model_k] = v
# model.load_state_dict(model_state)
model.load_state_dict(state['model'])
print('model state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state))