Change state saving to every epoch

This commit is contained in:
Yin Li 2020-03-07 18:03:19 -05:00
parent ccb323e6ee
commit 93d973b5c8

View File

@@ -261,11 +261,7 @@ def gpu_worker(local_rank, node, args):
state_file = 'state_{}.pth' state_file = 'state_{}.pth'
torch.save(state, ckpt_file) torch.save(state, ckpt_file)
del state del state
if good:
shutil.copyfile(ckpt_file, state_file.format(epoch + 1)) shutil.copyfile(ckpt_file, state_file.format(epoch + 1))
#if os.path.isfile(state_file.format(epoch)):
# os.remove(state_file.format(epoch))
dist.destroy_process_group() dist.destroy_process_group()