Revert removal of saving/loading optimizer & scheduler states
Why did I do that?
This commit is contained in:
parent 85efb9e3a3
commit 7be3153206
@@ -144,18 +144,6 @@ def gpu_worker(local_rank, node, args):
     if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
             or not args.load_state):
-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
-                m.weight.data.normal_(0.0, args.init_weight_std)
-            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
-                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
-                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
-                if m.affine:
-                    # NOTE: dispersion from DCGAN, why?
-                    m.weight.data.normal_(1.0, args.init_weight_std)
-                    m.bias.data.fill_(0)
-
         if args.init_weight_std is not None:
             model.apply(init_weights)
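For context, a standalone sketch of what the initializer removed above does, rewritten as a closure over the standard deviation rather than the script's args object; make_init_weights, std, and the toy model below are illustrative names only, not part of this commit. nn.LayerNorm is left out of the norm-layer tuple here because it exposes elementwise_affine rather than affine.

import torch.nn as nn

def make_init_weights(std):
    def init_weights(m):
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                          nn.ConvTranspose1d, nn.ConvTranspose2d,
                          nn.ConvTranspose3d)):
            m.weight.data.normal_(0.0, std)   # zero-mean normal for weights
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                            nn.SyncBatchNorm, nn.GroupNorm,
                            nn.InstanceNorm1d, nn.InstanceNorm2d,
                            nn.InstanceNorm3d)):
            if m.affine:
                # DCGAN-style norm-layer init: scale drawn around 1, bias zeroed
                m.weight.data.normal_(1.0, std)
                m.bias.data.fill_(0)
    return init_weights

# model.apply() walks every submodule, as in the hunk above
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.apply(make_init_weights(0.02))          # 0.02 is the DCGAN value, for illustration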
@@ -171,6 +159,9 @@ def gpu_worker(local_rank, node, args):
         load_model_state_dict(model.module, state['model'],
                 strict=args.load_state_strict)
 
+        optimizer.load_state_dict(state['optimizer'])
+        scheduler.load_state_dict(state['scheduler'])
+
         torch.set_rng_state(state['rng'].cpu())  # move rng state back
 
         if rank == 0:
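A minimal sketch of the resume path these two lines restore, assuming placeholder names ckpt_path, model, optimizer, scheduler, and device for the script's checkpoint link, network, optimizer, LR scheduler, and target device; none of these are defined by this commit itself.

import torch

def resume(ckpt_path, model, optimizer, scheduler, device):
    state = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(state['model'])          # network weights
    optimizer.load_state_dict(state['optimizer'])  # e.g. momentum buffers, step counts
    scheduler.load_state_dict(state['scheduler'])  # e.g. last_epoch and internal counters

    torch.set_rng_state(state['rng'].cpu())        # RNG state tensor must live on CPU

    return state['epoch'], state['min_loss']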
@@ -218,6 +209,8 @@ def gpu_worker(local_rank, node, args):
             state = {
                 'epoch': epoch + 1,
                 'model': model.module.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'scheduler': scheduler.state_dict(),
                 'rng': torch.get_rng_state(),
                 'min_loss': min_loss,
             }
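The matching save side, sketched under the same placeholder names plus epoch, min_loss, and an output path; restoring the two dict entries above is what lets the load_state_dict calls in the earlier hunk find their keys.

import torch

def save_checkpoint(ckpt_path, epoch, model, optimizer, scheduler, min_loss):
    state = {
        'epoch': epoch + 1,                   # epoch to resume from
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),  # entries restored by this commit
        'scheduler': scheduler.state_dict(),
        'rng': torch.get_rng_state(),         # CPU RNG state tensor
        'min_loss': min_loss,
    }
    torch.save(state, ckpt_path)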
@@ -421,6 +414,19 @@ def dist_init(rank, args):
         os.remove(dist_file)
 
 
+def init_weights(m):
+    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+        m.weight.data.normal_(0.0, args.init_weight_std)
+    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+        nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
+        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+        if m.affine:
+            # NOTE: dispersion from DCGAN, why?
+            m.weight.data.normal_(1.0, args.init_weight_std)
+            m.bias.data.fill_(0)
+
+
 def set_requires_grad(module, requires_grad=False):
     for param in module.parameters():
         param.requires_grad = requires_grad
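set_requires_grad appears in this hunk only as unchanged context; a small usage sketch follows, with a toy critic network standing in for whatever module one wants to freeze and later unfreeze (the name critic is illustrative, not from the source).

import torch.nn as nn

def set_requires_grad(module, requires_grad=False):
    for param in module.parameters():
        param.requires_grad = requires_grad

# Toy usage: freeze a network's parameters, then turn gradients back on
critic = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

set_requires_grad(critic, False)
assert not any(p.requires_grad for p in critic.parameters())

set_requires_grad(critic, True)
assert all(p.requires_grad for p in critic.parameters())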