Add checkpoint symlink to state file

This commit is contained in:
Yin Li 2020-04-16 17:50:08 -04:00
parent d01d0cee83
commit 01a60cc0c7

View File

@ -1,5 +1,4 @@
import os
import shutil
import socket
import time
import sys
@ -258,11 +257,15 @@ def gpu_worker(local_rank, node, args):
}
if args.adv:
state['adv_model'] = adv_model.module.state_dict()
ckpt_file = 'checkpoint.pth'
state_file = 'state_{}.pth'
torch.save(state, ckpt_file)
state_file = 'state_{}.pth'.format(epoch + 1)
torch.save(state, state_file)
del state
shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
ckpt_link = 'checkpoint.pth'
tmp_link = '{}.pth'.format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link)
dist.destroy_process_group()