Change .pth to .pt following torch convention

This commit is contained in:
Yin Li 2020-06-20 18:31:09 -04:00
parent c3877ae982
commit c3d7456d0c
4 changed files with 8 additions and 7 deletions

View File

@ -70,8 +70,9 @@ follow [Customization](#customization).
#### Files generated #### Files generated
* `*.out`: job stdout and stderr * `*.out`: job stdout and stderr
* `state_*.pth`: training state including the model parameters * `state_{i}.pt`: training state after the i-th epoch including the
* `checkpoint.pth`: symlink to the latest state model state
* `checkpoint.pt`: symlink to the latest state
* `runs/`: directories of tensorboard logs * `runs/`: directories of tensorboard logs

View File

@ -24,7 +24,7 @@ from .models import (narrow_like,
from .utils import import_attr, load_model_state_dict from .utils import import_attr, load_model_state_dict
ckpt_link = 'checkpoint.pth' ckpt_link = 'checkpoint.pt'
def node_worker(args): def node_worker(args):
@ -283,11 +283,11 @@ def gpu_worker(local_rank, node, args):
if args.adv: if args.adv:
state['adv_model'] = adv_model.module.state_dict() state['adv_model'] = adv_model.module.state_dict()
state_file = 'state_{}.pth'.format(epoch + 1) state_file = 'state_{}.pt'.format(epoch + 1)
torch.save(state, state_file) torch.save(state, state_file)
del state del state
tmp_link = '{}.pth'.format(time.time()) tmp_link = '{}.pt'.format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link) os.rename(tmp_link, ckpt_link)

View File

@ -37,7 +37,7 @@ m2m.py test \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \ --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
--model VNet \ --model VNet \
--load-state best_model.pth \ --load-state best_model.pt \
--batches 1 --loader-workers 0 \ --batches 1 --loader-workers 0 \
--cache --cache

View File

@ -37,7 +37,7 @@ m2m.py test \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \ --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
--model VNet \ --model VNet \
--load-state best_model.pth \ --load-state best_model.pt \
--batches 1 --loader-workers 0 \ --batches 1 --loader-workers 0 \
--cache --cache