Add checkpoint symlink to state file
This commit is contained in:
parent
d01d0cee83
commit
01a60cc0c7
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
import sys
|
import sys
|
||||||
@ -258,11 +257,15 @@ def gpu_worker(local_rank, node, args):
|
|||||||
}
|
}
|
||||||
if args.adv:
|
if args.adv:
|
||||||
state['adv_model'] = adv_model.module.state_dict()
|
state['adv_model'] = adv_model.module.state_dict()
|
||||||
ckpt_file = 'checkpoint.pth'
|
|
||||||
state_file = 'state_{}.pth'
|
state_file = 'state_{}.pth'.format(epoch + 1)
|
||||||
torch.save(state, ckpt_file)
|
torch.save(state, state_file)
|
||||||
del state
|
del state
|
||||||
shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
|
|
||||||
|
ckpt_link = 'checkpoint.pth'
|
||||||
|
tmp_link = '{}.pth'.format(time.time())
|
||||||
|
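os.symlink(state_file, tmp_link) # os.symlink() fails if the link name already exists, so create a uniquely named temp link and rename it over the old checkpoint link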
|
||||||
|
os.rename(tmp_link, ckpt_link)
|
||||||
|
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user