Revert removal of saving/loading optimizer & scheduler states

Why did I do that?
This commit is contained in:
Yin Li 2020-09-12 16:04:09 -04:00
parent 85efb9e3a3
commit 7be3153206

View file

@ -144,18 +144,6 @@ def gpu_worker(local_rank, node, args):
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
or not args.load_state):
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0)
if args.init_weight_std is not None:
model.apply(init_weights)
@ -171,6 +159,9 @@ def gpu_worker(local_rank, node, args):
load_model_state_dict(model.module, state['model'],
strict=args.load_state_strict)
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
torch.set_rng_state(state['rng'].cpu()) # move rng state back
if rank == 0:
@ -218,6 +209,8 @@ def gpu_worker(local_rank, node, args):
state = {
'epoch': epoch + 1,
'model': model.module.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'rng': torch.get_rng_state(),
'min_loss': min_loss,
}
@ -421,6 +414,19 @@ def dist_init(rank, args):
os.remove(dist_file)
def init_weights(m):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_std)
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.affine:
# NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_std)
m.bias.data.fill_(0)
def set_requires_grad(module, requires_grad=False):
for param in module.parameters():
param.requires_grad = requires_grad