Add auto loading of checkpoint
This commit is contained in:
parent
c9f468c568
commit
996c0d3aed
6 changed files with 37 additions and 31 deletions
|
@ -1,5 +1,7 @@
|
|||
import argparse
|
||||
|
||||
from .train import ckpt_link
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
@ -43,8 +45,10 @@ def add_common_args(parser):
|
|||
help='model from .models')
|
||||
parser.add_argument('--criterion', default='MSELoss', type=str,
|
||||
help='model criterion from torch.nn')
|
||||
parser.add_argument('--load-state', default='', type=str,
|
||||
help='path to load the states of model, optimizer, rng, etc.')
|
||||
parser.add_argument('--load-state', default=ckpt_link, type=str,
|
||||
help='path to load the states of model, optimizer, rng, etc. '
|
||||
'Default is the checkpoint. '
|
||||
'Start from scratch if the checkpoint does not exist')
|
||||
parser.add_argument('--load-state-non-strict', action='store_false',
|
||||
help='allow incompatible keys when loading model states',
|
||||
dest='load_state_strict')
|
||||
|
|
|
@ -23,6 +23,9 @@ from .models import (narrow_like,
|
|||
from .state import load_model_state_dict
|
||||
|
||||
|
||||
ckpt_link = 'checkpoint.pth'
|
||||
|
||||
|
||||
def node_worker(args):
|
||||
if 'SLURM_STEP_NUM_NODES' in os.environ:
|
||||
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
|
||||
|
@ -163,7 +166,31 @@ def gpu_worker(local_rank, node, args):
|
|||
adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
|
||||
factor=0.1, patience=10, verbose=True)
|
||||
|
||||
if args.load_state:
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_scale)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_scale)
|
||||
m.bias.data.fill_(0)
|
||||
|
||||
if args.init_weight_scale is not None:
|
||||
model.apply(init_weights)
|
||||
|
||||
if args.adv:
|
||||
adv_model.apply(init_weights)
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
if rank == 0:
|
||||
min_loss = None
|
||||
else:
|
||||
state = torch.load(args.load_state, map_location=device)
|
||||
|
||||
start_epoch = state['epoch']
|
||||
|
@ -181,31 +208,11 @@ def gpu_worker(local_rank, node, args):
|
|||
min_loss = state['min_loss']
|
||||
if args.adv and 'adv_model' not in state:
|
||||
min_loss = None # restarting with adversary wipes the record
|
||||
|
||||
print('state at epoch {} loaded from {}'.format(
|
||||
state['epoch'], args.load_state), flush=True)
|
||||
|
||||
del state
|
||||
else:
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_scale)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_scale)
|
||||
m.bias.data.fill_(0)
|
||||
if args.init_weight_scale is not None:
|
||||
model.apply(init_weights)
|
||||
if args.adv:
|
||||
adv_model.apply(init_weights)
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
if rank == 0:
|
||||
min_loss = None
|
||||
|
||||
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
||||
|
||||
|
@ -248,8 +255,8 @@ def gpu_worker(local_rank, node, args):
|
|||
except AttributeError:
|
||||
logger.close() # old pytorch
|
||||
|
||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if good and epoch >= args.adv_start:
|
||||
if ((min_loss is None or epoch_loss[0] < min_loss[0])
|
||||
and epoch >= args.adv_start):
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
|
@ -265,7 +272,6 @@ def gpu_worker(local_rank, node, args):
|
|||
torch.save(state, state_file)
|
||||
del state
|
||||
|
||||
ckpt_link = 'checkpoint.pth'
|
||||
tmp_link = '{}.pth'.format(time.time())
|
||||
os.symlink(state_file, tmp_link) # workaround to overwrite
|
||||
os.rename(tmp_link, ckpt_link)
|
||||
|
|
|
@ -40,7 +40,6 @@ srun m2m.py train \
|
|||
--lr 0.0001 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
|
||||
|
||||
date
|
||||
|
|
|
@ -41,7 +41,6 @@ srun m2m.py train \
|
|||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
|
||||
|
||||
date
|
||||
|
|
|
@ -43,7 +43,6 @@ srun m2m.py train \
|
|||
--cache --div-data
|
||||
# --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
|
||||
# --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
|
||||
# --load-state checkpoint.pth \
|
||||
|
||||
|
||||
date
|
||||
|
|
|
@ -41,7 +41,6 @@ srun m2m.py train \
|
|||
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
|
||||
--epochs 1024 --seed $RANDOM \
|
||||
--cache --div-data
|
||||
# --load-state checkpoint.pth \
|
||||
|
||||
|
||||
date
|
||||
|
|
Loading…
Reference in a new issue