Basic cleanup

This commit is contained in:
Guilhem Lavaux 2024-04-03 09:44:01 +02:00
parent 68e466c983
commit 90a14d56b9
3 changed files with 16 additions and 9 deletions

View file

@ -5,14 +5,14 @@ def adv_model_wrapper(module):
"""Wrap an adversary model to also take lists of Tensors as input,
to be concatenated along the batch dimension
"""
class new_module(module):
class _new_module(module):
def forward(self, x):
if not isinstance(x, torch.Tensor):
x = torch.cat(x, dim=0)
return super().forward(x)
return new_module
return _new_module
def adv_criterion_wrapper(module):
@ -22,7 +22,7 @@ def adv_criterion_wrapper(module):
* expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors
"""
class new_module(module):
class _new_module(module):
def forward(self, input, target):
assert isinstance(input, torch.Tensor)
@ -50,4 +50,4 @@ def adv_criterion_wrapper(module):
return torch.split(input, size, dim=0)
return new_module
return _new_module

View file

@ -1,7 +1,7 @@
import torch
def power(x):
def power(x: torch.Tensor):
"""Compute power spectra of input fields
Each field should have batch and channel dimensions followed by spatial

View file

@ -416,7 +416,7 @@ def dist_init(rank, args):
os.remove(dist_file)
def init_weights(m):
def init_weights(m, args):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_std)
@ -429,13 +429,20 @@ def init_weights(m):
m.bias.data.fill_(0)
def set_requires_grad(module, requires_grad=False):
def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
for param in module.parameters():
param.requires_grad = requires_grad
def get_grads(model):
"""gradients of the weights of the first and the last layer
def get_grads(model: torch.nn.Module):
"""
Calculate the gradients of the weights of the first and the last layer.
Args:
model: The model for which to calculate the gradients.
Returns:
A list containing the norms of the gradients of the first and the last layer weights.
"""
grads = list(p.grad for n, p in model.named_parameters()
if '.weight' in n)