From c3d7456d0c5b462b96f88d439ae14222fe3720e6 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Sat, 20 Jun 2020 18:31:09 -0400 Subject: [PATCH] Change .pth to .pt following torch convention --- README.md | 5 +++-- map2map/train.py | 6 +++--- scripts/dis2dis-test.slurm | 2 +- scripts/vel2vel-test.slurm | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 78321e4..65129ca 100644 --- a/README.md +++ b/README.md @@ -70,8 +70,9 @@ follow [Customization](#customization). #### Files generated * `*.out`: job stdout and stderr -* `state_*.pth`: training state including the model parameters -* `checkpoint.pth`: symlink to the latest state +* `state_{i}.pt`: training state after the i-th epoch including the + model state +* `checkpoint.pt`: symlink to the latest state * `runs/`: directories of tensorboard logs diff --git a/map2map/train.py b/map2map/train.py index f87dac7..3d44fdf 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -24,7 +24,7 @@ from .models import (narrow_like, from .utils import import_attr, load_model_state_dict -ckpt_link = 'checkpoint.pth' +ckpt_link = 'checkpoint.pt' def node_worker(args): @@ -283,11 +283,11 @@ def gpu_worker(local_rank, node, args): if args.adv: state['adv_model'] = adv_model.module.state_dict() - state_file = 'state_{}.pth'.format(epoch + 1) + state_file = 'state_{}.pt'.format(epoch + 1) torch.save(state, state_file) del state - tmp_link = '{}.pth'.format(time.time()) + tmp_link = '{}.pt'.format(time.time()) os.symlink(state_file, tmp_link) # workaround to overwrite os.rename(tmp_link, ckpt_link) diff --git a/scripts/dis2dis-test.slurm b/scripts/dis2dis-test.slurm index 345979f..0e035c6 100644 --- a/scripts/dis2dis-test.slurm +++ b/scripts/dis2dis-test.slurm @@ -37,7 +37,7 @@ m2m.py test \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --in-norms cosmology.dis --tgt-norms cosmology.dis --crop 256 --pad 20 \ --model VNet \ - --load-state best_model.pth \ + --load-state best_model.pt \ --batches 1 --loader-workers 0 \ --cache diff --git a/scripts/vel2vel-test.slurm b/scripts/vel2vel-test.slurm index 5ea210a..7c9c3e9 100644 --- a/scripts/vel2vel-test.slurm +++ b/scripts/vel2vel-test.slurm @@ -37,7 +37,7 @@ m2m.py test \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --in-norms cosmology.vel --tgt-norms cosmology.vel --crop 256 --pad 20 \ --model VNet \ - --load-state best_model.pth \ + --load-state best_model.pt \ --batches 1 --loader-workers 0 \ --cache