This commit is contained in:
Yin Li 2019-12-12 19:26:57 -05:00
parent 6d021ec949
commit d03bcb59a1

View File

@ -146,12 +146,12 @@ def gpu_worker(local_rank, args):
}
ckpt_file = 'checkpoint.pth'
best_file = 'best_model_{}.pth'
torch.save(state, filename)
torch.save(state, ckpt_file)
del state
if min_loss is None or val_loss < min_loss:
min_loss = val_loss
shutil.copyfile(filename, best_file.format(epoch + 1))
shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
if os.path.isfile(best_file.format(epoch)):
os.remove(best_file.format(epoch))