Change .pth to .pt following torch convention
This commit is contained in:
parent
c3877ae982
commit
c3d7456d0c
@ -70,8 +70,9 @@ follow [Customization](#customization).
|
|||||||
#### Files generated
|
#### Files generated
|
||||||
|
|
||||||
* `*.out`: job stdout and stderr
|
* `*.out`: job stdout and stderr
|
||||||
* `state_*.pth`: training state including the model parameters
|
* `state_{i}.pt`: training state after the i-th epoch including the
|
||||||
* `checkpoint.pth`: symlink to the latest state
|
model state
|
||||||
|
* `checkpoint.pt`: symlink to the latest state
|
||||||
* `runs/`: directories of tensorboard logs
|
* `runs/`: directories of tensorboard logs
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from .models import (narrow_like,
|
|||||||
from .utils import import_attr, load_model_state_dict
|
from .utils import import_attr, load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
ckpt_link = 'checkpoint.pth'
|
ckpt_link = 'checkpoint.pt'
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
@ -283,11 +283,11 @@ def gpu_worker(local_rank, node, args):
|
|||||||
if args.adv:
|
if args.adv:
|
||||||
state['adv_model'] = adv_model.module.state_dict()
|
state['adv_model'] = adv_model.module.state_dict()
|
||||||
|
|
||||||
state_file = 'state_{}.pth'.format(epoch + 1)
|
state_file = 'state_{}.pt'.format(epoch + 1)
|
||||||
torch.save(state, state_file)
|
torch.save(state, state_file)
|
||||||
del state
|
del state
|
||||||
|
|
||||||
tmp_link = '{}.pth'.format(time.time())
|
tmp_link = '{}.pt'.format(time.time())
|
||||||
os.symlink(state_file, tmp_link) # workaround to overwrite
|
os.symlink(state_file, tmp_link) # workaround to overwrite
|
||||||
os.rename(tmp_link, ckpt_link)
|
os.rename(tmp_link, ckpt_link)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ m2m.py test \
|
|||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
|
--in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \
|
||||||
--model VNet \
|
--model VNet \
|
||||||
--load-state best_model.pth \
|
--load-state best_model.pt \
|
||||||
--batches 1 --loader-workers 0 \
|
--batches 1 --loader-workers 0 \
|
||||||
--cache
|
--cache
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ m2m.py test \
|
|||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
|
--in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \
|
||||||
--model VNet \
|
--model VNet \
|
||||||
--load-state best_model.pth \
|
--load-state best_model.pt \
|
||||||
--batches 1 --loader-workers 0 \
|
--batches 1 --loader-workers 0 \
|
||||||
--cache
|
--cache
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user