Remove saving / loading optimizer & scheduler states
This commit is contained in:
parent
890a459363
commit
23b903f81b
@ -153,15 +153,11 @@ def gpu_worker(local_rank, node, args):
|
||||
start_epoch = state['epoch']
|
||||
|
||||
model.module.load_state_dict(state['model'])
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
scheduler.load_state_dict(state['scheduler'])
|
||||
|
||||
if 'adv_model' in state and args.adv:
|
||||
args.adv_epoch = state['adv_epoch']
|
||||
|
||||
adv_model.module.load_state_dict(state['adv_model'])
|
||||
adv_optimizer.load_state_dict(state['adv_optimizer'])
|
||||
adv_scheduler.load_state_dict(state['adv_scheduler'])
|
||||
elif 'adv_model' not in state and args.adv:
|
||||
args.adv_epoch = start_epoch
|
||||
|
||||
@ -231,8 +227,6 @@ def gpu_worker(local_rank, node, args):
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
@ -240,8 +234,6 @@ def gpu_worker(local_rank, node, args):
|
||||
state.update({
|
||||
'adv_epoch': args.adv_epoch,
|
||||
'adv_model': adv_model.module.state_dict(),
|
||||
'adv_optimizer': adv_optimizer.state_dict(),
|
||||
'adv_scheduler': adv_scheduler.state_dict(),
|
||||
})
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
best_file = 'best_model_{}.pth'
|
||||
|
Loading…
Reference in New Issue
Block a user