diff --git a/map2map/train.py b/map2map/train.py index ce83a4c..dfcfe88 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -1,5 +1,4 @@ import os -import shutil import socket import time import sys @@ -258,11 +257,15 @@ def gpu_worker(local_rank, node, args): } if args.adv: state['adv_model'] = adv_model.module.state_dict() - ckpt_file = 'checkpoint.pth' - state_file = 'state_{}.pth' - torch.save(state, ckpt_file) + + state_file = 'state_{}.pth'.format(epoch + 1) + torch.save(state, state_file) del state - shutil.copyfile(ckpt_file, state_file.format(epoch + 1)) + + ckpt_link = 'checkpoint.pth' + tmp_link = '{}.pth'.format(time.time()) + os.symlink(state_file, tmp_link) # workaround to overwrite + os.rename(tmp_link, ckpt_link) dist.destroy_process_group()