Add auto loading of checkpoint
parent c9f468c568
commit 996c0d3aed

@@ -1,5 +1,7 @@
 import argparse
 
+from .train import ckpt_link
+
 
 def get_args():
     parser = argparse.ArgumentParser(

@@ -43,8 +45,10 @@ def add_common_args(parser):
             help='model from .models')
     parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
-    parser.add_argument('--load-state', default='', type=str,
-            help='path to load the states of model, optimizer, rng, etc.')
+    parser.add_argument('--load-state', default=ckpt_link, type=str,
+            help='path to load the states of model, optimizer, rng, etc. '
+            'Default is the checkpoint. '
+            'Start from scratch if the checkpoint does not exist')
     parser.add_argument('--load-state-non-strict', action='store_false',
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')

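Note: with the default changed from '' to ckpt_link, a plain run now resumes
from checkpoint.pth automatically whenever it exists, and passing an empty
--load-state restores the old start-from-scratch behavior. A standalone
sketch of the default semantics, not part of the diff; the local ckpt_link
here mirrors the constant added to train.py in the next hunk:

    # Standalone sketch, not from the repo.
    import argparse

    ckpt_link = 'checkpoint.pth'  # mirrors the constant defined in train.py

    parser = argparse.ArgumentParser()
    parser.add_argument('--load-state', default=ckpt_link, type=str)

    assert parser.parse_args([]).load_state == 'checkpoint.pth'
    assert parser.parse_args(['--load-state', '']).load_state == ''  # from scratch
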
@@ -23,6 +23,9 @@ from .models import (narrow_like,
 from .state import load_model_state_dict
 
 
+ckpt_link = 'checkpoint.pth'
+
+
 def node_worker(args):
     if 'SLURM_STEP_NUM_NODES' in os.environ:
         args.nodes = int(os.environ['SLURM_STEP_NUM_NODES'])

@@ -163,7 +166,31 @@ def gpu_worker(local_rank, node, args):
     adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
             factor=0.1, patience=10, verbose=True)
 
-    if args.load_state:
+    if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
+            or not args.load_state):
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+                m.weight.data.normal_(0.0, args.init_weight_scale)
+            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
+                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+                if m.affine:
+                    # NOTE: dispersion from DCGAN, why?
+                    m.weight.data.normal_(1.0, args.init_weight_scale)
+                    m.bias.data.fill_(0)
+
+        if args.init_weight_scale is not None:
+            model.apply(init_weights)
+
+            if args.adv:
+                adv_model.apply(init_weights)
+
+        start_epoch = 0
+
+        if rank == 0:
+            min_loss = None
+    else:
         state = torch.load(args.load_state, map_location=device)
 
         start_epoch = state['epoch']

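Note: the new compound condition decides when to initialize from scratch:
either loading was explicitly disabled with an empty path, or the implicit
checkpoint link is requested but no checkpoint file exists yet. A distilled
sketch of that predicate, standalone and not part of the diff; the helper
name start_from_scratch is hypothetical:

    # Standalone sketch, not from the repo; start_from_scratch is hypothetical.
    import os

    ckpt_link = 'checkpoint.pth'

    def start_from_scratch(load_state):
        # True when loading is disabled, or when the default checkpoint
        # link is requested but no checkpoint has been written yet.
        return (load_state == ckpt_link and not os.path.isfile(ckpt_link)
                or not load_state)

    assert start_from_scratch('')  # empty path disables loading
    # Without checkpoint.pth on disk, start_from_scratch(ckpt_link) is True;
    # once a checkpoint has been saved, it turns False and training resumes.
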
@@ -181,31 +208,11 @@ def gpu_worker(local_rank, node, args):
         min_loss = state['min_loss']
         if args.adv and 'adv_model' not in state:
             min_loss = None  # restarting with adversary wipes the record
 
         print('state at epoch {} loaded from {}'.format(
             state['epoch'], args.load_state), flush=True)
 
         del state
-    else:
-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
-                m.weight.data.normal_(0.0, args.init_weight_scale)
-            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
-                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
-                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
-                if m.affine:
-                    # NOTE: dispersion from DCGAN, why?
-                    m.weight.data.normal_(1.0, args.init_weight_scale)
-                    m.bias.data.fill_(0)
-
-        if args.init_weight_scale is not None:
-            model.apply(init_weights)
-            if args.adv:
-                adv_model.apply(init_weights)
-        start_epoch = 0
-        if rank == 0:
-            min_loss = None
 
     torch.backends.cudnn.benchmark = True  # NOTE: test perf
 

@@ -248,8 +255,8 @@ def gpu_worker(local_rank, node, args):
         except AttributeError:
             logger.close()  # old pytorch
 
-        good = min_loss is None or epoch_loss[0] < min_loss[0]
-        if good and epoch >= args.adv_start:
+        if ((min_loss is None or epoch_loss[0] < min_loss[0])
+                and epoch >= args.adv_start):
             min_loss = epoch_loss
 
         state = {

@@ -265,7 +272,6 @@ def gpu_worker(local_rank, node, args):
             torch.save(state, state_file)
             del state
 
-            ckpt_link = 'checkpoint.pth'
             tmp_link = '{}.pth'.format(time.time())
             os.symlink(state_file, tmp_link)  # workaround to overwrite
             os.rename(tmp_link, ckpt_link)

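Note: the link update above now targets the module-level ckpt_link instead of
a local duplicate. os.symlink() refuses to overwrite an existing name, so a
uniquely named temporary symlink is created first and then renamed over
checkpoint.pth; os.rename() replaces the old link atomically on POSIX
filesystems. A standalone sketch, not part of the diff; the state_file name
is hypothetical:

    # Standalone sketch, not from the repo; state_file name is hypothetical.
    import os
    import time

    ckpt_link = 'checkpoint.pth'
    state_file = 'state_1.pth'
    open(state_file, 'wb').close()  # stand-in for torch.save(state, state_file)

    tmp_link = '{}.pth'.format(time.time())
    os.symlink(state_file, tmp_link)  # symlink to a fresh name cannot collide
    os.rename(tmp_link, ckpt_link)    # atomically repoint checkpoint.pth
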
@@ -40,7 +40,6 @@ srun m2m.py train \
     --lr 0.0001 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date

@@ -41,7 +41,6 @@ srun m2m.py train \
     --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date

@@ -43,7 +43,6 @@ srun m2m.py train \
     --cache --div-data
     # --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files_1,$data_root_dir/$in_dir/$val_dirs/$in_files_2" \
     # --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_1,$data_root_dir/$tgt_dir/$val_dirs/$tgt_files_2" \
-    # --load-state checkpoint.pth \
 
 
 date

@@ -41,7 +41,6 @@ srun m2m.py train \
    --lr 0.0001 --adv-lr 0.0004 --batches 1 --loader-workers 0 \
     --epochs 1024 --seed $RANDOM \
     --cache --div-data
-    # --load-state checkpoint.pth \
 
 
 date