Remove saving / loading optimizer & scheduler states
Conflicts: map2map/train.py
This commit is contained in:
parent
1cd34c2eed
commit
3fb9708575
@ -161,12 +161,8 @@ def gpu_worker(local_rank, args):
|
||||
args.start_epoch = state['epoch']
|
||||
args.adv_delay += args.start_epoch
|
||||
model.module.load_state_dict(state['model'])
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
scheduler.load_state_dict(state['scheduler'])
|
||||
if 'adv_model' in state and args.adv:
|
||||
adv_model.module.load_state_dict(state['adv_model'])
|
||||
adv_optimizer.load_state_dict(state['adv_optimizer'])
|
||||
adv_scheduler.load_state_dict(state['adv_scheduler'])
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
if args.rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
@ -230,16 +226,12 @@ def gpu_worker(local_rank, args):
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
if args.adv:
|
||||
state.update({
|
||||
'adv_model': adv_model.module.state_dict(),
|
||||
'adv_optimizer': adv_optimizer.state_dict(),
|
||||
'adv_scheduler': adv_scheduler.state_dict(),
|
||||
})
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
best_file = 'best_model_{}.pth'
|
||||
|
Loading…
Reference in New Issue
Block a user