Add non-strict model state loading

This commit is contained in:
Yin Li 2020-02-09 17:32:09 -05:00
parent f831afbccf
commit 09493ad4ec
4 changed files with 26 additions and 10 deletions

View File

@ -38,6 +38,9 @@ def add_common_args(parser):
help='model criterion from torch.nn') help='model criterion from torch.nn')
parser.add_argument('--load-state', default='', type=str, parser.add_argument('--load-state', default='', type=str,
help='path to load the states of model, optimizer, rng, etc.') help='path to load the states of model, optimizer, rng, etc.')
parser.add_argument('--load-state-non-strict', action='store_false',
help='allow incompatible keys when loading model states',
dest='load_state_strict')
parser.add_argument('--batches', default=1, type=int, parser.add_argument('--batches', default=1, type=int,
help='mini-batch size, per GPU in training or in total in testing') help='mini-batch size, per GPU in training or in total in testing')

9
map2map/state.py Normal file
View File

@ -0,0 +1,9 @@
import warnings
from pprint import pprint
def load_model_state_dict(model, state_dict, strict=True):
bad_keys = model.load_state_dict(state_dict, strict)
if bad_keys.missing_keys or bad_keys.unexpected_keys:
warnings.warn(pprint(repr(bad_keys)))

View File

@ -6,6 +6,7 @@ from torch.utils.data import DataLoader
from .data import FieldDataset from .data import FieldDataset
from . import models from . import models
from .models import narrow_like from .models import narrow_like
from .state import load_model_state_dict
def test(args): def test(args):
@ -39,7 +40,7 @@ def test(args):
device = torch.device('cpu') device = torch.device('cpu')
state = torch.load(args.load_state, map_location=device) state = torch.load(args.load_state, map_location=device)
model.load_state_dict(state['model']) load_model_state_dict(model, state['model'], strict=args.load_state_strict)
print('model state at epoch {} loaded from {}'.format( print('model state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state)) state['epoch'], args.load_state))
del state del state

View File

@ -15,6 +15,7 @@ from .data.figures import fig3d
from . import models from . import models
from .models import narrow_like from .models import narrow_like
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
from .state import load_model_state_dict
def node_worker(args): def node_worker(args):
@ -161,9 +162,11 @@ def gpu_worker(local_rank, args):
state = torch.load(args.load_state, map_location=args.device) state = torch.load(args.load_state, map_location=args.device)
args.start_epoch = state['epoch'] args.start_epoch = state['epoch']
args.adv_delay += args.start_epoch args.adv_delay += args.start_epoch
model.module.load_state_dict(state['model']) load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
if 'adv_model' in state and args.adv: if 'adv_model' in state and args.adv:
adv_model.module.load_state_dict(state['adv_model']) load_model_state_dict(adv_model.module, state['adv_model'],
strict=args.load_state_strict)
torch.set_rng_state(state['rng'].cpu()) # move rng state back torch.set_rng_state(state['rng'].cpu()) # move rng state back
if args.rank == 0: if args.rank == 0:
min_loss = state['min_loss'] min_loss = state['min_loss']
@ -220,8 +223,8 @@ def gpu_worker(local_rank, args):
print(end='', flush=True) print(end='', flush=True)
args.logger.close() args.logger.close()
is_best = min_loss is None or epoch_loss[0] < min_loss[0] good = min_loss is None or epoch_loss[0] < min_loss[0]
if is_best and epoch >= args.adv_delay: if good and epoch >= args.adv_delay:
min_loss = epoch_loss min_loss = epoch_loss
state = { state = {
@ -235,14 +238,14 @@ def gpu_worker(local_rank, args):
'adv_model': adv_model.module.state_dict(), 'adv_model': adv_model.module.state_dict(),
}) })
ckpt_file = 'checkpoint.pth' ckpt_file = 'checkpoint.pth'
best_file = 'best_model_{}.pth' state_file = 'state_{}.pth'
torch.save(state, ckpt_file) torch.save(state, ckpt_file)
del state del state
if is_best: if good:
shutil.copyfile(ckpt_file, best_file.format(epoch + 1)) shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
#if os.path.isfile(best_file.format(epoch)): #if os.path.isfile(state_file.format(epoch)):
# os.remove(best_file.format(epoch)) # os.remove(state_file.format(epoch))
dist.destroy_process_group() dist.destroy_process_group()