Basic cleanup

parent 68e466c983
commit 90a14d56b9
@@ -5,14 +5,14 @@ def adv_model_wrapper(module):
     """Wrap an adversary model to also take lists of Tensors as input,
     to be concatenated along the batch dimension
     """
-    class new_module(module):
+    class _new_module(module):
         def forward(self, x):
             if not isinstance(x, torch.Tensor):
                 x = torch.cat(x, dim=0)

             return super().forward(x)

-    return new_module
+    return _new_module


 def adv_criterion_wrapper(module):
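For context, the renamed wrapper from this hunk can be exercised as follows. This is a sketch: wrapping nn.Linear is only an illustration, not a call site from this repository.

import torch
import torch.nn as nn

def adv_model_wrapper(module):
    """Wrap an adversary model to also take lists of Tensors as input,
    to be concatenated along the batch dimension
    """
    class _new_module(module):
        def forward(self, x):
            # Lists (or tuples) of Tensors are merged into one batch.
            if not isinstance(x, torch.Tensor):
                x = torch.cat(x, dim=0)
            return super().forward(x)

    return _new_module

# Example: a wrapped nn.Linear accepts a list of differently sized batches.
WrappedLinear = adv_model_wrapper(nn.Linear)
disc = WrappedLinear(8, 1)
out = disc([torch.randn(2, 8), torch.randn(3, 8)])  # shape (5, 1)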
@@ -22,7 +22,7 @@ def adv_criterion_wrapper(module):
     * expand target shape as that of input
     * return a list of losses, one for each pair of input and target Tensors
     """
-    class new_module(module):
+    class _new_module(module):
         def forward(self, input, target):
             assert isinstance(input, torch.Tensor)

@@ -50,4 +50,4 @@ def adv_criterion_wrapper(module):

         return torch.split(input, size, dim=0)

-    return new_module
+    return _new_module
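The middle of adv_criterion_wrapper is elided between these two hunks, so here is a minimal sketch of the behaviour its docstring describes (expand the target, then return one loss per pair). The sizes bookkeeping and the helper name split_losses are assumptions, not code from the commit.

import torch
import torch.nn as nn

def split_losses(criterion, input, sizes, target):
    # Expand target shape as that of input, as the docstring describes.
    target = target.expand_as(input)
    # Undo the batch concatenation done by the model wrapper.
    inputs = torch.split(input, sizes, dim=0)
    targets = torch.split(target, sizes, dim=0)
    # One loss per pair of input and target Tensors.
    return [criterion(i, t) for i, t in zip(inputs, targets)]

crit = nn.BCEWithLogitsLoss()
scores = torch.randn(5, 1)  # e.g. 2 samples from one source, 3 from another
losses = split_losses(crit, scores, [2, 3], torch.tensor(1.0))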
@@ -1,7 +1,7 @@
 import torch


-def power(x):
+def power(x: torch.Tensor):
     """Compute power spectra of input fields

     Each field should have batch and channel dimensions followed by spatial
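The body of power is not part of this hunk. As a rough illustration of what "batch and channel dimensions followed by spatial" implies, a generic |FFT|^2 power spectrum over the spatial axes might look like this; it is a sketch, not the repository's implementation.

import torch

def power_sketch(x: torch.Tensor):
    # x has shape (batch, channel, *spatial); transform spatial axes only.
    spatial_dims = tuple(range(2, x.dim()))
    f = torch.fft.fftn(x, dim=spatial_dims)
    return f.abs().square()  # power at each wavenumber

fields = torch.randn(4, 1, 32, 32)  # batch of 4 single-channel 2D fields
spec = power_sketch(fields)         # same shape as fields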
@@ -416,7 +416,7 @@ def dist_init(rank, args):
     os.remove(dist_file)


-def init_weights(m):
+def init_weights(m, args):
     if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
             nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
         m.weight.data.normal_(0.0, args.init_weight_std)
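Because init_weights now takes args, it can no longer be passed straight to Module.apply; binding the extra argument first is the usual pattern. A sketch follows, in which the Args stand-in and the call site are assumptions.

import functools
import torch.nn as nn

def init_weights(m, args):
    # Weight branch from the hunk above; the bias branch is elided here.
    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                      nn.ConvTranspose1d, nn.ConvTranspose2d,
                      nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, args.init_weight_std)

class Args:  # hypothetical stand-in for the argparse namespace
    init_weight_std = 0.02

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
model.apply(functools.partial(init_weights, args=Args()))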
@@ -429,13 +429,20 @@ def init_weights(m):
         m.bias.data.fill_(0)


-def set_requires_grad(module, requires_grad=False):
+def set_requires_grad(module: torch.nn.Module, requires_grad: bool = False):
     for param in module.parameters():
         param.requires_grad = requires_grad


-def get_grads(model):
-    """gradients of the weights of the first and the last layer
+def get_grads(model: torch.nn.Module):
+    """
+    Calculate the gradients of the weights of the first and the last layer.
+
+    Args:
+        model: The model for which to calculate the gradients.
+
+    Returns:
+        A list containing the norms of the gradients of the first and the last layer weights.
     """
     grads = list(p.grad for n, p in model.named_parameters()
                  if '.weight' in n)
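The hunk cuts off after the grads comprehension. Completing it to match the new docstring (return the norms of the first and last layer weight gradients) would look roughly like this; the norm computation is an assumption based on the docstring, not the commit's actual code.

import torch
import torch.nn as nn

def get_grads(model: nn.Module):
    grads = [p.grad for n, p in model.named_parameters()
             if '.weight' in n]
    # Assumed continuation: norms of the first and last weight gradients.
    return [grads[0].norm(), grads[-1].norm()]

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
model(torch.randn(2, 4)).sum().backward()
first_norm, last_norm = get_grads(model)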