Remove saving / loading optimizer & scheduler states
Conflicts: map2map/train.py
This commit is contained in:
parent
1cd34c2eed
commit
3fb9708575
@ -161,12 +161,8 @@ def gpu_worker(local_rank, args):
|
|||||||
args.start_epoch = state['epoch']
|
args.start_epoch = state['epoch']
|
||||||
args.adv_delay += args.start_epoch
|
args.adv_delay += args.start_epoch
|
||||||
model.module.load_state_dict(state['model'])
|
model.module.load_state_dict(state['model'])
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
|
||||||
scheduler.load_state_dict(state['scheduler'])
|
|
||||||
if 'adv_model' in state and args.adv:
|
if 'adv_model' in state and args.adv:
|
||||||
adv_model.module.load_state_dict(state['adv_model'])
|
adv_model.module.load_state_dict(state['adv_model'])
|
||||||
adv_optimizer.load_state_dict(state['adv_optimizer'])
|
|
||||||
adv_scheduler.load_state_dict(state['adv_scheduler'])
|
|
||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
@ -230,16 +226,12 @@ def gpu_worker(local_rank, args):
|
|||||||
state = {
|
state = {
|
||||||
'epoch': epoch + 1,
|
'epoch': epoch + 1,
|
||||||
'model': model.module.state_dict(),
|
'model': model.module.state_dict(),
|
||||||
'optimizer': optimizer.state_dict(),
|
|
||||||
'scheduler': scheduler.state_dict(),
|
|
||||||
'rng': torch.get_rng_state(),
|
'rng': torch.get_rng_state(),
|
||||||
'min_loss': min_loss,
|
'min_loss': min_loss,
|
||||||
}
|
}
|
||||||
if args.adv:
|
if args.adv:
|
||||||
state.update({
|
state.update({
|
||||||
'adv_model': adv_model.module.state_dict(),
|
'adv_model': adv_model.module.state_dict(),
|
||||||
'adv_optimizer': adv_optimizer.state_dict(),
|
|
||||||
'adv_scheduler': adv_scheduler.state_dict(),
|
|
||||||
})
|
})
|
||||||
ckpt_file = 'checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
best_file = 'best_model_{}.pth'
|
best_file = 'best_model_{}.pth'
|
||||||
|
Loading…
Reference in New Issue
Block a user