diff --git a/map2map/train.py b/map2map/train.py
index 1460f5e..bc23e3f 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -144,18 +144,6 @@ def gpu_worker(local_rank, node, args):
 
     if (args.load_state == ckpt_link and not os.path.isfile(ckpt_link)
             or not args.load_state):
-        def init_weights(m):
-            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
-                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
-                m.weight.data.normal_(0.0, args.init_weight_std)
-            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
-                nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
-                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
-                if m.affine:
-                    # NOTE: dispersion from DCGAN, why?
-                    m.weight.data.normal_(1.0, args.init_weight_std)
-                    m.bias.data.fill_(0)
-
         if args.init_weight_std is not None:
             model.apply(init_weights)
 
@@ -171,6 +159,9 @@ def gpu_worker(local_rank, node, args):
         load_model_state_dict(model.module, state['model'],
                 strict=args.load_state_strict)
 
+        optimizer.load_state_dict(state['optimizer'])
+        scheduler.load_state_dict(state['scheduler'])
+
         torch.set_rng_state(state['rng'].cpu())  # move rng state back
 
         if rank == 0:
@@ -218,6 +209,8 @@ def gpu_worker(local_rank, node, args):
             state = {
                 'epoch': epoch + 1,
                 'model': model.module.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'scheduler': scheduler.state_dict(),
                 'rng': torch.get_rng_state(),
                 'min_loss': min_loss,
             }
@@ -421,6 +414,19 @@ def dist_init(rank, args):
         os.remove(dist_file)
 
 
+def init_weights(m):
+    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+        nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+        m.weight.data.normal_(0.0, args.init_weight_std)
+    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+        nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
+        nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+        if m.affine:
+            # NOTE: dispersion from DCGAN, why?
+            m.weight.data.normal_(1.0, args.init_weight_std)
+            m.bias.data.fill_(0)
+
+
 def set_requires_grad(module, requires_grad=False):
     for param in module.parameters():
         param.requires_grad = requires_grad
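
Note on the checkpoint hunks: adding 'optimizer' and 'scheduler' entries to the
state dict and loading them back means a resumed run keeps the optimizer's
internal buffers (e.g. Adam moments) and the learning-rate schedule position
instead of reinitializing them. A minimal round-trip sketch of the same layout,
assuming the model/optimizer are already constructed; save_checkpoint and
load_checkpoint are hypothetical helper names for illustration, not functions
in map2map:

import torch

def save_checkpoint(path, epoch, model, optimizer, scheduler, min_loss):
    # Same state-dict layout as the diff: epoch counter, weights,
    # optimizer/scheduler internals, CPU RNG state, and best loss so far.
    torch.save({
        'epoch': epoch + 1,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'rng': torch.get_rng_state(),  # returns a CPU ByteTensor
        'min_loss': min_loss,
    }, path)

def load_checkpoint(path, model, optimizer, scheduler):
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    # set_rng_state expects a CPU ByteTensor, hence the .cpu() in the diff
    # ("move rng state back") in case the tensor ended up on a GPU.
    torch.set_rng_state(state['rng'].cpu())
    return state['epoch'], state['min_loss']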
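
One caveat in the last hunk: the nested init_weights closed over gpu_worker's
args parameter, but the copy moved to module level still references args, which
train.py never defines as a global, so model.apply(init_weights) would raise
NameError when it runs. A sketch of one possible fix, binding the dispersion
explicitly with functools.partial; the std parameter and the partial call are
a suggestion, not part of the diff:

from functools import partial

import torch.nn as nn

def init_weights(m, std):
    # Same logic as the diff, with the dispersion passed in explicitly.
    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
            nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, std)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
            nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
            nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
        if m.affine:
            m.weight.data.normal_(1.0, std)
            m.bias.data.fill_(0)

# inside gpu_worker, the call site would then become:
#     model.apply(partial(init_weights, std=args.init_weight_std))

Binding the value through partial keeps init_weights reusable at module level
without relying on global state.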