Fix bug from 0533150
This commit is contained in:
parent
6d021ec949
commit
d03bcb59a1
@ -146,12 +146,12 @@ def gpu_worker(local_rank, args):
|
||||
}
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
best_file = 'best_model_{}.pth'
|
||||
torch.save(state, filename)
|
||||
torch.save(state, ckpt_file)
|
||||
del state
|
||||
|
||||
if min_loss is None or val_loss < min_loss:
|
||||
min_loss = val_loss
|
||||
shutil.copyfile(filename, best_file.format(epoch + 1))
|
||||
shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
|
||||
if os.path.isfile(best_file.format(epoch)):
|
||||
os.remove(best_file.format(epoch))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user