Remove saving / loading optimizer & scheduler states

Conflicts:
	map2map/train.py
This commit is contained in:
Yin Li 2020-02-07 10:02:32 -05:00
parent 1cd34c2eed
commit 3fb9708575

View File

@ -161,12 +161,8 @@ def gpu_worker(local_rank, args):
args.start_epoch = state['epoch'] args.start_epoch = state['epoch']
args.adv_delay += args.start_epoch args.adv_delay += args.start_epoch
model.module.load_state_dict(state['model']) model.module.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
if 'adv_model' in state and args.adv: if 'adv_model' in state and args.adv:
adv_model.module.load_state_dict(state['adv_model']) adv_model.module.load_state_dict(state['adv_model'])
adv_optimizer.load_state_dict(state['adv_optimizer'])
adv_scheduler.load_state_dict(state['adv_scheduler'])
torch.set_rng_state(state['rng'].cpu()) # move rng state back torch.set_rng_state(state['rng'].cpu()) # move rng state back
if args.rank == 0: if args.rank == 0:
min_loss = state['min_loss'] min_loss = state['min_loss']
@ -230,16 +226,12 @@ def gpu_worker(local_rank, args):
state = { state = {
'epoch': epoch + 1, 'epoch': epoch + 1,
'model': model.module.state_dict(), 'model': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'rng': torch.get_rng_state(), 'rng': torch.get_rng_state(),
'min_loss': min_loss, 'min_loss': min_loss,
} }
if args.adv: if args.adv:
state.update({ state.update({
'adv_model': adv_model.module.state_dict(), 'adv_model': adv_model.module.state_dict(),
'adv_optimizer': adv_optimizer.state_dict(),
'adv_scheduler': adv_scheduler.state_dict(),
}) })
ckpt_file = 'checkpoint.pth' ckpt_file = 'checkpoint.pth'
best_file = 'best_model_{}.pth' best_file = 'best_model_{}.pth'