Revert removal of saving/loading optimizer & scheduler states
Why did I do that?
This commit is contained in:
parent
85efb9e3a3
commit
7be3153206
1 changed files with 18 additions and 12 deletions
|
@ -144,18 +144,6 @@ def gpu_worker(local_rank, node, args):
|
|||
|
||||
if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
|
||||
or not args.load_state):
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_std)
|
||||
m.bias.data.fill_(0)
|
||||
|
||||
if args.init_weight_std is not None:
|
||||
model.apply(init_weights)
|
||||
|
||||
|
@ -171,6 +159,9 @@ def gpu_worker(local_rank, node, args):
|
|||
load_model_state_dict(model.module, state['model'],
|
||||
strict=args.load_state_strict)
|
||||
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
scheduler.load_state_dict(state['scheduler'])
|
||||
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
|
||||
if rank == 0:
|
||||
|
@ -218,6 +209,8 @@ def gpu_worker(local_rank, node, args):
|
|||
state = {
|
||||
'epoch': epoch + 1,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
|
@ -421,6 +414,19 @@ def dist_init(rank, args):
|
|||
os.remove(dist_file)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
|
||||
nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
|
||||
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
|
||||
if m.affine:
|
||||
# NOTE: dispersion from DCGAN, why?
|
||||
m.weight.data.normal_(1.0, args.init_weight_std)
|
||||
m.bias.data.fill_(0)
|
||||
|
||||
|
||||
def set_requires_grad(module, requires_grad=False):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
|
Loading…
Reference in a new issue