Remove saving / loading optimizer & scheduler states
This commit is contained in:
parent
890a459363
commit
23b903f81b
@ -153,15 +153,11 @@ def gpu_worker(local_rank, node, args):
|
|||||||
start_epoch = state['epoch']
|
start_epoch = state['epoch']
|
||||||
|
|
||||||
model.module.load_state_dict(state['model'])
|
model.module.load_state_dict(state['model'])
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
|
||||||
scheduler.load_state_dict(state['scheduler'])
|
|
||||||
|
|
||||||
if 'adv_model' in state and args.adv:
|
if 'adv_model' in state and args.adv:
|
||||||
args.adv_epoch = state['adv_epoch']
|
args.adv_epoch = state['adv_epoch']
|
||||||
|
|
||||||
adv_model.module.load_state_dict(state['adv_model'])
|
adv_model.module.load_state_dict(state['adv_model'])
|
||||||
adv_optimizer.load_state_dict(state['adv_optimizer'])
|
|
||||||
adv_scheduler.load_state_dict(state['adv_scheduler'])
|
|
||||||
elif 'adv_model' not in state and args.adv:
|
elif 'adv_model' not in state and args.adv:
|
||||||
args.adv_epoch = start_epoch
|
args.adv_epoch = start_epoch
|
||||||
|
|
||||||
@ -231,8 +227,6 @@ def gpu_worker(local_rank, node, args):
|
|||||||
state = {
|
state = {
|
||||||
'epoch': epoch + 1,
|
'epoch': epoch + 1,
|
||||||
'model': model.module.state_dict(),
|
'model': model.module.state_dict(),
|
||||||
'optimizer': optimizer.state_dict(),
|
|
||||||
'scheduler': scheduler.state_dict(),
|
|
||||||
'rng': torch.get_rng_state(),
|
'rng': torch.get_rng_state(),
|
||||||
'min_loss': min_loss,
|
'min_loss': min_loss,
|
||||||
}
|
}
|
||||||
@ -240,8 +234,6 @@ def gpu_worker(local_rank, node, args):
|
|||||||
state.update({
|
state.update({
|
||||||
'adv_epoch': args.adv_epoch,
|
'adv_epoch': args.adv_epoch,
|
||||||
'adv_model': adv_model.module.state_dict(),
|
'adv_model': adv_model.module.state_dict(),
|
||||||
'adv_optimizer': adv_optimizer.state_dict(),
|
|
||||||
'adv_scheduler': adv_scheduler.state_dict(),
|
|
||||||
})
|
})
|
||||||
ckpt_file = 'checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
best_file = 'best_model_{}.pth'
|
best_file = 'best_model_{}.pth'
|
||||||
|
Loading…
Reference in New Issue
Block a user