Save past best models by not overwriting them

This commit is contained in:
Yin Li 2019-12-12 16:51:47 -05:00
parent 946805c6be
commit 0533150194

View File

@ -144,13 +144,16 @@ def gpu_worker(local_rank, args):
'rng' : torch.get_rng_state(),
'min_loss': min_loss,
}
filename='checkpoint.pth'
ckpt_file = 'checkpoint.pth'
best_file = 'best_model_{}.pth'
torch.save(state, filename)
del state
if min_loss is None or val_loss < min_loss:
min_loss = val_loss
shutil.copyfile(filename, 'best_model.pth')
shutil.copyfile(filename, best_file.format(epoch + 1))
if os.path.isfile(best_file.format(epoch)):
os.remove(best_file.format(epoch))
destroy_process_group()