Add non-strict model state loading
This commit is contained in:
parent
f831afbccf
commit
09493ad4ec
@ -38,6 +38,9 @@ def add_common_args(parser):
|
|||||||
help='model criterion from torch.nn')
|
help='model criterion from torch.nn')
|
||||||
parser.add_argument('--load-state', default='', type=str,
|
parser.add_argument('--load-state', default='', type=str,
|
||||||
help='path to load the states of model, optimizer, rng, etc.')
|
help='path to load the states of model, optimizer, rng, etc.')
|
||||||
|
parser.add_argument('--load-state-non-strict', action='store_false',
|
||||||
|
help='allow incompatible keys when loading model states',
|
||||||
|
dest='load_state_strict')
|
||||||
|
|
||||||
parser.add_argument('--batches', default=1, type=int,
|
parser.add_argument('--batches', default=1, type=int,
|
||||||
help='mini-batch size, per GPU in training or in total in testing')
|
help='mini-batch size, per GPU in training or in total in testing')
|
||||||
|
9
map2map/state.py
Normal file
9
map2map/state.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
import warnings
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_state_dict(model, state_dict, strict=True):
|
||||||
|
bad_keys = model.load_state_dict(state_dict, strict)
|
||||||
|
|
||||||
|
if bad_keys.missing_keys or bad_keys.unexpected_keys:
|
||||||
|
warnings.warn(pprint(repr(bad_keys)))
|
@ -6,6 +6,7 @@ from torch.utils.data import DataLoader
|
|||||||
from .data import FieldDataset
|
from .data import FieldDataset
|
||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_like
|
from .models import narrow_like
|
||||||
|
from .state import load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
def test(args):
|
def test(args):
|
||||||
@ -39,7 +40,7 @@ def test(args):
|
|||||||
|
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
state = torch.load(args.load_state, map_location=device)
|
state = torch.load(args.load_state, map_location=device)
|
||||||
model.load_state_dict(state['model'])
|
load_model_state_dict(model, state['model'], strict=args.load_state_strict)
|
||||||
print('model state at epoch {} loaded from {}'.format(
|
print('model state at epoch {} loaded from {}'.format(
|
||||||
state['epoch'], args.load_state))
|
state['epoch'], args.load_state))
|
||||||
del state
|
del state
|
||||||
|
@ -15,6 +15,7 @@ from .data.figures import fig3d
|
|||||||
from . import models
|
from . import models
|
||||||
from .models import narrow_like
|
from .models import narrow_like
|
||||||
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||||
|
from .state import load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
@ -161,9 +162,11 @@ def gpu_worker(local_rank, args):
|
|||||||
state = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=args.device)
|
||||||
args.start_epoch = state['epoch']
|
args.start_epoch = state['epoch']
|
||||||
args.adv_delay += args.start_epoch
|
args.adv_delay += args.start_epoch
|
||||||
model.module.load_state_dict(state['model'])
|
load_model_state_dict(model.module, state['model'],
|
||||||
|
strict=args.load_state_strict)
|
||||||
if 'adv_model' in state and args.adv:
|
if 'adv_model' in state and args.adv:
|
||||||
adv_model.module.load_state_dict(state['adv_model'])
|
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||||
|
strict=args.load_state_strict)
|
||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
@ -220,8 +223,8 @@ def gpu_worker(local_rank, args):
|
|||||||
print(end='', flush=True)
|
print(end='', flush=True)
|
||||||
args.logger.close()
|
args.logger.close()
|
||||||
|
|
||||||
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||||
if is_best and epoch >= args.adv_delay:
|
if good and epoch >= args.adv_delay:
|
||||||
min_loss = epoch_loss
|
min_loss = epoch_loss
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
@ -235,14 +238,14 @@ def gpu_worker(local_rank, args):
|
|||||||
'adv_model': adv_model.module.state_dict(),
|
'adv_model': adv_model.module.state_dict(),
|
||||||
})
|
})
|
||||||
ckpt_file = 'checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
best_file = 'best_model_{}.pth'
|
state_file = 'state_{}.pth'
|
||||||
torch.save(state, ckpt_file)
|
torch.save(state, ckpt_file)
|
||||||
del state
|
del state
|
||||||
|
|
||||||
if is_best:
|
if good:
|
||||||
shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
|
shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
|
||||||
#if os.path.isfile(best_file.format(epoch)):
|
#if os.path.isfile(state_file.format(epoch)):
|
||||||
# os.remove(best_file.format(epoch))
|
# os.remove(state_file.format(epoch))
|
||||||
|
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user