Remove saving / loading optimizer & scheduler states

Yin Li 2020-02-07 10:02:32 -05:00
parent 890a459363
commit 23b903f81b


@@ -153,15 +153,11 @@ def gpu_worker(local_rank, node, args):
         start_epoch = state['epoch']
         model.module.load_state_dict(state['model'])
-        optimizer.load_state_dict(state['optimizer'])
-        scheduler.load_state_dict(state['scheduler'])
         if 'adv_model' in state and args.adv:
             args.adv_epoch = state['adv_epoch']
             adv_model.module.load_state_dict(state['adv_model'])
-            adv_optimizer.load_state_dict(state['adv_optimizer'])
-            adv_scheduler.load_state_dict(state['adv_scheduler'])
         elif 'adv_model' not in state and args.adv:
             args.adv_epoch = start_epoch
@@ -231,8 +227,6 @@ def gpu_worker(local_rank, node, args):
         state = {
             'epoch': epoch + 1,
             'model': model.module.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'scheduler': scheduler.state_dict(),
             'rng': torch.get_rng_state(),
             'min_loss': min_loss,
         }
@@ -240,8 +234,6 @@ def gpu_worker(local_rank, node, args):
             state.update({
                 'adv_epoch': args.adv_epoch,
                 'adv_model': adv_model.module.state_dict(),
-                'adv_optimizer': adv_optimizer.state_dict(),
-                'adv_scheduler': adv_scheduler.state_dict(),
             })
         ckpt_file = 'checkpoint.pth'
         best_file = 'best_model_{}.pth'
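
For context, a minimal sketch (not part of the commit) of what saving and resuming look like after this change: only the epoch, model weights, RNG state, and best loss are written, and on resume only the model weights and RNG state are restored. A plain model is used here for brevity, whereas the real code wraps it in DistributedDataParallel and accesses model.module; the toy model, learning rate, and file name are placeholders.

import torch
from torch import nn, optim

# Illustrative setup only; the optimizer and scheduler are created fresh on
# every run and are no longer part of the checkpoint.
model = nn.Linear(4, 4)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Checkpoint contents after this commit: no 'optimizer' / 'scheduler' keys.
state = {
    'epoch': 1,
    'model': model.state_dict(),
    'rng': torch.get_rng_state(),
    'min_loss': float('inf'),
}
torch.save(state, 'checkpoint.pth')

# Resuming: only the model weights and RNG state come back from disk; the
# optimizer and scheduler keep their freshly constructed defaults.
state = torch.load('checkpoint.pth')
start_epoch = state['epoch']
model.load_state_dict(state['model'])
torch.set_rng_state(state['rng'])

One consequence of this scheme is that a resumed run restarts the optimizer's internal statistics (e.g. Adam's moment estimates) and the learning-rate schedule from scratch instead of continuing them from the interrupted run.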