Add non-strict model state loading
This commit is contained in:
parent
f831afbccf
commit
09493ad4ec
4 changed files with 26 additions and 10 deletions
|
@ -38,6 +38,9 @@ def add_common_args(parser):
|
|||
help='model criterion from torch.nn')
|
||||
parser.add_argument('--load-state', default='', type=str,
|
||||
help='path to load the states of model, optimizer, rng, etc.')
|
||||
parser.add_argument('--load-state-non-strict', action='store_false',
|
||||
help='allow incompatible keys when loading model states',
|
||||
dest='load_state_strict')
|
||||
|
||||
parser.add_argument('--batches', default=1, type=int,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
|
|
9
map2map/state.py
Normal file
9
map2map/state.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
import warnings
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
def load_model_state_dict(model, state_dict, strict=True):
    """Load ``state_dict`` into ``model`` and warn about incompatible keys.

    The call is forwarded to ``model.load_state_dict(state_dict, strict)``.
    If the object it returns reports any ``missing_keys`` or
    ``unexpected_keys``, a warning describing them is issued.

    Args:
        model: object exposing ``load_state_dict(state_dict, strict)``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        state_dict: the state mapping to load into ``model``.
        strict: forwarded unchanged to ``model.load_state_dict``.

    Returns:
        None.
    """
    bad_keys = model.load_state_dict(state_dict, strict)

    if bad_keys.missing_keys or bad_keys.unexpected_keys:
        # warnings.warn() needs the message text itself; pprint() would
        # print to stdout and hand back None, losing the key names.
        warnings.warn(repr(bad_keys))
|
|
@ -6,6 +6,7 @@ from torch.utils.data import DataLoader
|
|||
from .data import FieldDataset
|
||||
from . import models
|
||||
from .models import narrow_like
|
||||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
def test(args):
|
||||
|
@ -39,7 +40,7 @@ def test(args):
|
|||
|
||||
device = torch.device('cpu')
|
||||
state = torch.load(args.load_state, map_location=device)
|
||||
model.load_state_dict(state['model'])
|
||||
load_model_state_dict(model, state['model'], strict=args.load_state_strict)
|
||||
print('model state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state))
|
||||
del state
|
||||
|
|
|
@ -15,6 +15,7 @@ from .data.figures import fig3d
|
|||
from . import models
|
||||
from .models import narrow_like
|
||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
|
@ -161,9 +162,11 @@ def gpu_worker(local_rank, args):
|
|||
state = torch.load(args.load_state, map_location=args.device)
|
||||
args.start_epoch = state['epoch']
|
||||
args.adv_delay += args.start_epoch
|
||||
model.module.load_state_dict(state['model'])
|
||||
load_model_state_dict(model.module, state['model'],
|
||||
strict=args.load_state_strict)
|
||||
if 'adv_model' in state and args.adv:
|
||||
adv_model.module.load_state_dict(state['adv_model'])
|
||||
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||
strict=args.load_state_strict)
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
if args.rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
|
@ -220,8 +223,8 @@ def gpu_worker(local_rank, args):
|
|||
print(end='', flush=True)
|
||||
args.logger.close()
|
||||
|
||||
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if is_best and epoch >= args.adv_delay:
|
||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if good and epoch >= args.adv_delay:
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
|
@ -235,14 +238,14 @@ def gpu_worker(local_rank, args):
|
|||
'adv_model': adv_model.module.state_dict(),
|
||||
})
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
best_file = 'best_model_{}.pth'
|
||||
state_file = 'state_{}.pth'
|
||||
torch.save(state, ckpt_file)
|
||||
del state
|
||||
|
||||
if is_best:
|
||||
shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
|
||||
#if os.path.isfile(best_file.format(epoch)):
|
||||
# os.remove(best_file.format(epoch))
|
||||
if good:
|
||||
shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
|
||||
#if os.path.isfile(state_file.format(epoch)):
|
||||
# os.remove(state_file.format(epoch))
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
|
Loading…
Reference in a new issue