diff --git a/map2map/models/adversary.py b/map2map/models/adversary.py index 1a8d0d3..4a20509 100644 --- a/map2map/models/adversary.py +++ b/map2map/models/adversary.py @@ -5,14 +5,14 @@ def adv_model_wrapper(module): """Wrap an adversary model to also take lists of Tensors as input, to be concatenated along the batch dimension """ - class new_module(module): + class _new_module(module): def forward(self, x): if not isinstance(x, torch.Tensor): x = torch.cat(x, dim=0) return super().forward(x) - return new_module + return _new_module def adv_criterion_wrapper(module): @@ -22,7 +22,7 @@ def adv_criterion_wrapper(module): * expand target shape as that of input * return a list of losses, one for each pair of input and target Tensors """ - class new_module(module): + class _new_module(module): def forward(self, input, target): assert isinstance(input, torch.Tensor) @@ -50,4 +50,4 @@ def adv_criterion_wrapper(module): return torch.split(input, size, dim=0) - return new_module + return _new_module diff --git a/map2map/models/power.py b/map2map/models/power.py index 7aaf249..19e32b7 100644 --- a/map2map/models/power.py +++ b/map2map/models/power.py @@ -1,7 +1,7 @@ import torch -def power(x): +def power(x: torch.Tensor): """Compute power spectra of input fields Each field should have batch and channel dimensions followed by spatial diff --git a/map2map/train.py b/map2map/train.py index db80b16..d4fb380 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -416,7 +416,7 @@ def dist_init(rank, args): os.remove(dist_file) -def init_weights(m): +def init_weights(m, args): if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): m.weight.data.normal_(0.0, args.init_weight_std) @@ -429,13 +429,20 @@ def init_weights(m): m.bias.data.fill_(0) -def set_requires_grad(module, requires_grad=False): +def set_requires_grad(module: torch.nn.Module, requires_grad : bool =False): for param in module.parameters(): param.requires_grad = requires_grad -def get_grads(model): - """gradients of the weights of the first and the last layer +def get_grads(model: torch.nn.Module): + """ + Calculate the gradients of the weights of the first and the last layer. + + Args: + model: The model for which to calculate the gradients. + + Returns: + A list containing the norms of the gradients of the first and the last layer weights. """ grads = list(p.grad for n, p in model.named_parameters() if '.weight' in n)