Basic cleanup

This commit is contained in:
Guilhem Lavaux 2024-04-03 09:44:01 +02:00
parent 68e466c983
commit 90a14d56b9
3 changed files with 16 additions and 9 deletions

View File

@ -5,14 +5,14 @@ def adv_model_wrapper(module):
"""Wrap an adversary model to also take lists of Tensors as input, """Wrap an adversary model to also take lists of Tensors as input,
to be concatenated along the batch dimension to be concatenated along the batch dimension
""" """
class new_module(module): class _new_module(module):
def forward(self, x): def forward(self, x):
if not isinstance(x, torch.Tensor): if not isinstance(x, torch.Tensor):
x = torch.cat(x, dim=0) x = torch.cat(x, dim=0)
return super().forward(x) return super().forward(x)
return new_module return _new_module
def adv_criterion_wrapper(module): def adv_criterion_wrapper(module):
@ -22,7 +22,7 @@ def adv_criterion_wrapper(module):
* expand target shape as that of input * expand target shape as that of input
* return a list of losses, one for each pair of input and target Tensors * return a list of losses, one for each pair of input and target Tensors
""" """
class new_module(module): class _new_module(module):
def forward(self, input, target): def forward(self, input, target):
assert isinstance(input, torch.Tensor) assert isinstance(input, torch.Tensor)
@ -50,4 +50,4 @@ def adv_criterion_wrapper(module):
return torch.split(input, size, dim=0) return torch.split(input, size, dim=0)
return new_module return _new_module

View File

@ -1,7 +1,7 @@
import torch import torch
def power(x): def power(x: torch.Tensor):
"""Compute power spectra of input fields """Compute power spectra of input fields
Each field should have batch and channel dimensions followed by spatial Each field should have batch and channel dimensions followed by spatial

View File

@ -416,7 +416,7 @@ def dist_init(rank, args):
os.remove(dist_file) os.remove(dist_file)
def init_weights(m): def init_weights(m, args):
if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
m.weight.data.normal_(0.0, args.init_weight_std) m.weight.data.normal_(0.0, args.init_weight_std)
@ -429,13 +429,20 @@ def init_weights(m):
m.bias.data.fill_(0) m.bias.data.fill_(0)
def set_requires_grad(module, requires_grad=False): def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False):
for param in module.parameters(): for param in module.parameters():
param.requires_grad = requires_grad param.requires_grad = requires_grad
def get_grads(model): def get_grads(model: torch.nn.Module):
"""gradients of the weights of the first and the last layer """
Calculate the gradients of the weights of the first and the last layer.
Args:
model: The model for which to calculate the gradients.
Returns:
A list containing the norms of the gradients of the first and the last layer weights.
""" """
grads = list(p.grad for n, p in model.named_parameters() grads = list(p.grad for n, p in model.named_parameters()
if '.weight' in n) if '.weight' in n)