Add style components from stylegan2

* styled convolution with mod and demod
* pixelnorm
* linear and conv layers with equalized learning rate
This commit is contained in:
Yin Li 2021-02-26 16:21:30 -05:00
parent 4706a902f6
commit a697845933

163
map2map/models/style.py Normal file
View file

@ -0,0 +1,163 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class PixelNorm(nn.Module):
"""Pixelwise normalization after conv layers.
See ProGAN, StyleGAN.
"""
def __init__(self):
super().__init__()
def forward(self, x, eps=1e-8):
return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
class LinearElr(nn.Module):
"""Linear layer with equalized learning rate.
See ProGAN, StyleGAN, and 1706.05350
Useful at all if not for regularization(1706.05350)?
"""
def __init__(self, in_size, out_size, bias=True, act=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_size, in_size))
self.wnorm = 1 / math.sqrt(in_size)
if bias:
self.bias = nn.Parameter(torch.zeros(out_size))
else:
self.register_parameter('bias', None)
self.act = act
def forward(self, x):
x = F.linear(x, self.weight * self.wnorm, bias=self.bias)
if self.act:
x = F.leaky_relu(x, negative_slope=0.2)
return x
class ConvElr3d(nn.Module):
"""Conv3d layer with equalized learning rate.
See ProGAN, StyleGAN, and 1706.05350
Useful at all if not for regularization(1706.05350)?
"""
def __init__(self, in_chan, out_chan, kernel_size,
stride=1, padding=0, bias=True):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
)
fan_in = in_chan * kernel_size ** 3
self.wnorm = 1 / math.sqrt(fan_in)
if bias:
self.bias = nn.Parameter(torch.zeros(out_chan))
else:
self.register_parameter('bias', None)
self.stride = stride
self.padding = padding
def forward(self, x):
x = F.conv2d(
x,
self.weight * self.wnorm,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return x
class ConvMod3d(nn.Module):
"""Convolution layer with modulation and demodulation, from StyleGAN2.
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
"""
def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
bias=True, resample=None):
super().__init__()
self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
mode='fan_in', nonlinearity='leaky_relu')
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
if resample is None:
K3 = (kernel_size,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = stride
self.conv = F.conv3d
elif resample == 'U':
K3 = (2,) * 3
# NOTE not clear to me why convtranspose have channels swapped
self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
self.stride = 2
self.conv = F.conv_transpose3d
elif resample == 'D':
K3 = (2,) * 3
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
self.stride = 2
self.conv = F.conv3d
else:
raise ValueError('resample type {} not supported'.format(resample))
self.resample = resample
nn.init.kaiming_uniform_(
self.weight, a=math.sqrt(5),
mode='fan_in', # effectively 'fan_out' for 'D'
nonlinearity='leaky_relu',
)
if bias:
self.bias = nn.Parameter(torch.zeros(out_chan))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in)
nn.init.uniform_(self.bias, -bound, bound)
else:
self.register_parameter('bias', None)
def forward(self, x, s, eps=1e-8):
N, Cin, *DHWin = x.shape
C0, C1, *K3 = self.weight.shape
if self.resample == 'U':
Cin, Cout = C0, C1
else:
Cout, Cin = C0, C1
s = F.linear(s, self.style_weight, bias=self.style_bias)
# modulation
if self.resample == 'U':
s = s.reshape(N, Cin, 1, 1, 1, 1)
else:
s = s.reshape(N, 1, Cin, 1, 1, 1)
w = self.weight * s
# demodulation
if self.resample == 'U':
fan_in_dim = (1, 3, 4, 5)
else:
fan_in_dim = (2, 3, 4, 5)
w = w * torch.rsqrt(w.pow(2).sum(dim=fan_in_dim, keepdim=True) + eps)
w = w.reshape(N * C0, C1, *K3)
x = x.reshape(1, N * Cin, *DHWin)
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
_, _, *DHWout = x.shape
x = x.reshape(N, Cout, *DHWout)
return x