Add auto loading of checkpoint

Yin Li 2020-04-21 18:29:57 -04:00
parent c9f468c568
commit 996c0d3aed
6 changed files with 37 additions and 31 deletions
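The gist of the change: --load-state now defaults to the 'checkpoint.pth' symlink defined in the training module, so a job restarted in the same directory resumes from the latest checkpoint automatically, and falls back to fresh weight initialization when that link does not exist yet (or when an empty path is passed explicitly). A minimal sketch of the new decision, mirroring the condition added in gpu_worker() below (resume_or_init is a hypothetical helper, not part of the diff):

import os

ckpt_link = 'checkpoint.pth'  # symlink the training loop keeps pointing at the latest state

def resume_or_init(load_state):
    # Hypothetical helper: decide whether to restore a saved state or
    # initialize from scratch, following the condition in the diff below.
    if (load_state == ckpt_link and not os.path.isfile(ckpt_link)
            or not load_state):
        return 'init'  # apply init_weights, start_epoch = 0
    return 'load'      # torch.load(load_state) and restore epoch, min_loss, ...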


@@ -1,5 +1,7 @@
 import argparse
 
+from .train import ckpt_link
+
 
 def get_args():
     parser = argparse.ArgumentParser(
@@ -43,8 +45,10 @@ def add_common_args(parser):
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
-    parser.add_argument('--load-state', default='', type=str,
-            help='path to load the states of model, optimizer, rng, etc.')
+    parser.add_argument('--load-state', default=ckpt_link, type=str,
+            help='path to load the states of model, optimizer, rng, etc. '
+                 'Default is the checkpoint. '
+                 'Start from scratch if the checkpoint does not exist')
     parser.add_argument('--load-state-non-strict', action='store_false',
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')
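Importing ckpt_link from .train keeps the default path defined in one place. As a small sketch of the resulting argparse behaviour (only this one option, not the full parser): omitting --load-state yields the checkpoint path, while an empty string still requests a fresh start.

from argparse import ArgumentParser

ckpt_link = 'checkpoint.pth'  # stands in for the value imported from .train

parser = ArgumentParser()
parser.add_argument('--load-state', default=ckpt_link, type=str)

print(parser.parse_args([]).load_state)                    # 'checkpoint.pth', auto resume if it exists
print(parser.parse_args(['--load-state', '']).load_state)  # '', start from scratch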


@@ -23,6 +23,9 @@ from .models import (narrow_like,
 from .state import load_model_state_dict
 
 
+ckpt_link = 'checkpoint.pth'
+
+
 def node_worker(args):
     if 'SLURM_STEP_NUM_NODES' in os.environ:
         args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])
@@ -163,7 +166,31 @@ def gpu_worker(local_rank, node, args):
         adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
             factor=0.1, patience=10, verbose=True)
 
-    if args.load_state:
+    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
+            or not args.load_state):
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+                m.weight.data.normal_(0.0, args.init_weight_scale)
+            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
+                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+                if m.affine:
+                    # NOTE: dispersion from DCGAN, why?
+                    m.weight.data.normal_(1.0, args.init_weight_scale)
+                    m.bias.data.fill_(0)
+
+        if args.init_weight_scale is not None:
+            model.apply(init_weights)
+            if args.adv:
+                adv_model.apply(init_weights)
+
+        start_epoch = 0
+
+        if rank == 0:
+            min_loss = None
+    else:
         state = torch.load(args.load_state, map_location=device)
 
         start_epoch = state['epoch']
@@ -181,31 +208,11 @@ def gpu_worker(local_rank, node, args):
             min_loss = state['min_loss']
             if args.adv and 'adv_model' not in state:
                 min_loss = None  # restarting with adversary wipes the record
 
         print('state at epoch {} loaded from {}'.format(
             state['epoch'], args.load_state), flush=True)
 
         del state
-    else:
-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
-                m.weight.data.normal_(0.0, args.init_weight_scale)
-            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
-                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
-                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
-                if m.affine:
-                    # NOTE: dispersion from DCGAN, why?
-                    m.weight.data.normal_(1.0, args.init_weight_scale)
-                    m.bias.data.fill_(0)
-
-        if args.init_weight_scale is not None:
-            model.apply(init_weights)
-            if args.adv:
-                adv_model.apply(init_weights)
-
-        start_epoch = 0
-
-        if rank == 0:
-            min_loss = None
 
     torch.backends.cudnn.benchmark = True  # NOTE: test perf
@@ -248,8 +255,8 @@ def gpu_worker(local_rank, node, args):
         except AttributeError:
             logger.close()  # old pytorch
 
-        good = min_loss is None or epoch_loss[0] < min_loss[0]
-        if good and epoch >= args.adv_start:
+        if ((min_loss is None or epoch_loss[0] < min_loss[0])
+                and epoch >= args.adv_start):
             min_loss = epoch_loss
 
             state = {
@@ -265,7 +272,6 @@ def gpu_worker(local_rank, node, args):
             torch.save(state, state_file)
             del state
 
-            ckpt_link = 'checkpoint.pth'
             tmp_link = '{}.pth'.format(time.time())
             os.symlink(state_file, tmp_link)  # workaround to overwrite
             os.rename(tmp_link, ckpt_link)
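The 'checkpoint.pth' link itself is refreshed with the symlink-then-rename workaround shown above, since os.symlink() cannot overwrite an existing link while os.rename() replaces its destination atomically on POSIX filesystems. A standalone sketch of that step (state_file stands for the state dict file just written by torch.save(); update_checkpoint_link is a hypothetical wrapper, not part of the diff):

import os
import time

ckpt_link = 'checkpoint.pth'

def update_checkpoint_link(state_file):
    # Create a uniquely named temporary link, then atomically move it
    # onto ckpt_link so the link always points at the latest state file.
    tmp_link = '{}.pth'.format(time.time())
    os.symlink(state_file, tmp_link)
    os.rename(tmp_link, ckpt_link)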


@@ -40,7 +40,6 @@ srun m2m.py train \
     --lr 0.0001 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date


@@ -41,7 +41,6 @@ srun m2m.py train \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date


@@ -43,7 +43,6 @@ srun m2m.py train \
     --cache --div-data
     # --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
     # --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
-    # --load-state checkpoint.pth \
 
 
 date


@@ -41,7 +41,6 @@ srun m2m.py train \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date