Fix spectral norm interference with gradient visualization

spectral norm re-registers the weight parameter as weight_orig, so not all
weight parameters end with ".weight". Match ".weight" as a substring instead.
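
A minimal sketch of the name change (hypothetical toy model; PyTorch's
torch.nn.utils.spectral_norm):

    import torch.nn as nn
    from torch.nn.utils import spectral_norm

    # hypothetical toy model; spectral norm wraps the second layer only
    model = nn.Sequential(
        nn.Conv2d(1, 8, 3),
        spectral_norm(nn.Conv2d(8, 1, 3)),
    )

    # spectral_norm re-registers the wrapped weight as 'weight_orig'
    # (its 'weight_u'/'weight_v' are buffers, not parameters), so the
    # old suffix filter misses it:
    print([n for n, _ in model.named_parameters() if n.endswith('weight')])
    # ['0.weight']  -- the spectral-normed layer is dropped
    print([n for n, _ in model.named_parameters() if '.weight' in n])
    # ['0.weight', '1.weight_orig']  -- both layers matched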
Yin Li 2020-03-04 16:21:00 -05:00
parent 33fd13e4ac
commit 5fbf447e2d


@@ -364,7 +364,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         # gradients of the weights of the first and the last layer
         grads = list(p.grad for n, p in model.named_parameters()
-                     if n.endswith('weight'))
+                     if '.weight' in n)
         grads = [grads[0], grads[-1]]
         grads = [g.detach().norm().item() for g in grads]
         logger.add_scalars('grad', {
@@ -373,7 +373,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         }, global_step=batch)
         if args.adv and epoch >= args.adv_start:
             grads = list(p.grad for n, p in adv_model.named_parameters()
-                         if n.endswith('weight'))
+                         if '.weight' in n)
             grads = [grads[0], grads[-1]]
             grads = [g.detach().norm().item() for g in grads]
             logger.add_scalars('grad/adv', {