Basic cleanup
This commit is contained in:
parent
68e466c983
commit
90a14d56b9
3 changed files with 16 additions and 9 deletions
|
@ -5,14 +5,14 @@ def adv_model_wrapper(module):
|
|||
"""Wrap an adversary model to also take lists of Tensors as input,
|
||||
to be concatenated along the batch dimension
|
||||
"""
|
||||
class new_module(module):
|
||||
class _new_module(module):
|
||||
def forward(self, x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.cat(x, dim=0)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
return new_module
|
||||
return _new_module
|
||||
|
||||
|
||||
def adv_criterion_wrapper(module):
|
||||
|
@ -22,7 +22,7 @@ def adv_criterion_wrapper(module):
|
|||
* expand target shape as that of input
|
||||
* return a list of losses, one for each pair of input and target Tensors
|
||||
"""
|
||||
class new_module(module):
|
||||
class _new_module(module):
|
||||
def forward(self, input, target):
|
||||
assert isinstance(input, torch.Tensor)
|
||||
|
||||
|
@ -50,4 +50,4 @@ def adv_criterion_wrapper(module):
|
|||
|
||||
return torch.split(input, size, dim=0)
|
||||
|
||||
return new_module
|
||||
return _new_module
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
|
||||
def power(x):
|
||||
def power(x: torch.Tensor):
|
||||
"""Compute power spectra of input fields
|
||||
|
||||
Each field should have batch and channel dimensions followed by spatial
|
||||
|
|
|
@ -416,7 +416,7 @@ def dist_init(rank, args):
|
|||
os.remove(dist_file)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
def init_weights(m, args):
|
||||
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
m.weight.data.normal_(0.0, args.init_weight_std)
|
||||
|
@ -429,13 +429,20 @@ def init_weights(m):
|
|||
m.bias.data.fill_(0)
|
||||
|
||||
|
||||
def set_requires_grad(module, requires_grad=False):
|
||||
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
|
||||
for param in module.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
|
||||
def get_grads(model):
|
||||
"""gradients of the weights of the first and the last layer
|
||||
def get_grads(model: torch.nn.Module):
|
||||
"""
|
||||
Calculate the gradients of the weights of the first and the last layer.
|
||||
|
||||
Args:
|
||||
model: The model for which to calculate the gradients.
|
||||
|
||||
Returns:
|
||||
A list containing the norms of the gradients of the first and the last layer weights.
|
||||
"""
|
||||
grads = list(p.grad for n, p in model.named_parameters()
|
||||
if '.weight' in n)
|
||||
|
|
Loading…
Reference in a new issue