Save past best models by not overwriting them
This commit is contained in:
parent
946805c6be
commit
0533150194
@ -144,13 +144,16 @@ def gpu_worker(local_rank, args):
|
|||||||
'rng' : torch.get_rng_state(),
|
'rng' : torch.get_rng_state(),
|
||||||
'min_loss': min_loss,
|
'min_loss': min_loss,
|
||||||
}
|
}
|
||||||
filename='checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
|
best_file = 'best_model_{}.pth'
|
||||||
torch.save(state, filename)
|
torch.save(state, filename)
|
||||||
del state
|
del state
|
||||||
|
|
||||||
if min_loss is None or val_loss < min_loss:
|
if min_loss is None or val_loss < min_loss:
|
||||||
min_loss = val_loss
|
min_loss = val_loss
|
||||||
shutil.copyfile(filename, 'best_model.pth')
|
shutil.copyfile(filename, best_file.format(epoch + 1))
|
||||||
|
if os.path.isfile(best_file.format(epoch)):
|
||||||
|
os.remove(best_file.format(epoch))
|
||||||
|
|
||||||
destroy_process_group()
|
destroy_process_group()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user