Add auto loading of checkpoint

This commit is contained in:
Yin Li 2020-04-21 18:29:57 -04:00
parent c9f468c568
commit 996c0d3aed
6 changed files with 37 additions and 31 deletions

View file

@ -1,5 +1,7 @@
import argparse
from .train import ckpt_link
def get_args():
parser = argparse.ArgumentParser(
@ -43,8 +45,10 @@ def add_common_args(parser):
help='model from .models')
parser.add_argument('--criterion', default='MSELoss', type=str,
help='model criterion from torch.nn')
parser.add_argument('--load-state', default='', type=str,
help='path to load the states of model, optimizer, rng, etc.')
parser.add_argument('--load-state', default=ckpt_link, type=str,
help='path to load the states of model, optimizer, rng, etc. '
'Default is the checkpoint. '
'Start from scratch if the checkpoint does not exist')
parser.add_argument('--load-state-non-strict', action='store_false',
help='allow incompatible keys when loading model states',
dest='load_state_strict')

View file

@ -23,6 +23,9 @@ from .models import (narrow_like,
from .state import load_model_state_dict
ckpt_link = 'checkpoint.pth'
def node_worker(args):
if 'SLURM_STEP_NUM_NODES' in os.environ:
args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
@ -163,7 +166,31 @@ def gpu_worker(local_rank, node, args):
adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
factor=0.1, patience=10, verbose=True)
if args.load_state:
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_scale)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_scale)
m.bias.data.fill_(0)
if args.init_weight_scale is not None:
model.apply(init_weights)
if args.adv:
adv_model.apply(init_weights)
start_epoch = 0
if rank == 0:
min_loss = None
else:
state = torch.load(args.load_state, map_location=device)
start_epoch = state['epoch']
@ -181,31 +208,11 @@ def gpu_worker(local_rank, node, args):
min_loss = state['min_loss']
if args.adv and 'adv_model' not in state:
min_loss = None # restarting with adversary wipes the record
print('state at epoch {} loaded from {}'.format(
state['epoch'], args.load_state), flush=True)
del state
else:
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_scale)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_scale)
m.bias.data.fill_(0)
if args.init_weight_scale is not None:
model.apply(init_weights)
if args.adv:
adv_model.apply(init_weights)
start_epoch = 0
if rank == 0:
min_loss = None
torch.backends.cudnn.benchmark = True # NOTE: test perf
@ -248,8 +255,8 @@ def gpu_worker(local_rank, node, args):
except AttributeError:
logger.close() # old pytorch
good = min_loss is None or epoch_loss[0] < min_loss[0]
if good and epoch >= args.adv_start:
if ((min_loss is None or epoch_loss[0] < min_loss[0])
and epoch >= args.adv_start):
min_loss = epoch_loss
state = {
@ -265,7 +272,6 @@ def gpu_worker(local_rank, node, args):
torch.save(state, state_file)
del state
ckpt_link = 'checkpoint.pth'
tmp_link = '{}.pth'.format(time.time())
os.symlink(state_file, tmp_link) # workaround to overwrite
os.rename(tmp_link, ckpt_link)

View file

@ -40,7 +40,6 @@ srun m2m.py train \
--lr 0.0001 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
# --load-state checkpoint.pth \
date

View file

@ -41,7 +41,6 @@ srun m2m.py train \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
# --load-state checkpoint.pth \
date

View file

@ -43,7 +43,6 @@ srun m2m.py train \
--cache --div-data
# --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
# --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
# --load-state checkpoint.pth \
date

View file

@ -41,7 +41,6 @@ srun m2m.py train \
--lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
--epochs 1024 --seed $RANDOM \
--cache --div-data
# --load-state checkpoint.pth \
date