Add style components from StyleGAN2

* styled convolution with modulation and demodulation
* pixel norm
* linear and conv layers with equalized learning rate
parent 4706a902f6
commit a697845933
map2map/models/style.py (new file, 163 lines)
@@ -0,0 +1,163 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PixelNorm(nn.Module):
    """Pixelwise normalization after conv layers.

    See ProGAN, StyleGAN.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, eps=1e-8):
        return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
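
# Usage sketch (illustration only, not part of the original commit).
# PixelNorm rescales each spatial location's channel vector to unit RMS,
# independent of batch statistics:
#
#   >>> norm = PixelNorm()
#   >>> x = torch.randn(4, 16, 8, 8, 8)  # N, C, D, H, W
#   >>> y = norm(x)  # per-voxel channel vectors now have ~unit RMS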


class LinearElr(nn.Module):
    """Linear layer with equalized learning rate.

    See ProGAN, StyleGAN, and arXiv:1706.05350.

    Is this useful at all, other than for the regularization effect
    (arXiv:1706.05350)?
    """
    def __init__(self, in_size, out_size, bias=True, act=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_size, in_size))
        self.wnorm = 1 / math.sqrt(in_size)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_size))
        else:
            self.register_parameter('bias', None)

        self.act = act

    def forward(self, x):
        x = F.linear(x, self.weight * self.wnorm, bias=self.bias)

        if self.act:
            x = F.leaky_relu(x, negative_slope=0.2)

        return x
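
# Sketch of the equalized-LR idea (illustration only): weights are stored
# with unit variance and rescaled by 1 / sqrt(fan_in) on every forward pass,
# so all layers see comparable effective gradient scales:
#
#   >>> fc = LinearElr(512, 256, act=True)  # any truthy act enables leaky ReLU
#   >>> w = torch.randn(8, 512)  # e.g. a batch of latent vectors
#   >>> y = fc(w)                # shape (8, 256)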


class ConvElr3d(nn.Module):
    """Conv3d layer with equalized learning rate.

    See ProGAN, StyleGAN, and arXiv:1706.05350.

    Is this useful at all, other than for the regularization effect
    (arXiv:1706.05350)?
    """
    def __init__(self, in_chan, out_chan, kernel_size,
                 stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
        )
        fan_in = in_chan * kernel_size ** 3
        self.wnorm = 1 / math.sqrt(fan_in)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_chan))
        else:
            self.register_parameter('bias', None)

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        x = F.conv3d(
            x,
            self.weight * self.wnorm,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return x
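
# Illustrative check (not in the original commit): the runtime weight scale
# uses fan_in = in_chan * kernel_size**3, matching Conv3d's He init scale:
#
#   >>> conv = ConvElr3d(16, 32, 3, padding=1)
#   >>> x = torch.randn(2, 16, 8, 8, 8)  # N, C, D, H, W
#   >>> y = conv(x)                      # shape (2, 32, 8, 8, 8)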


class ConvMod3d(nn.Module):
    """Convolution layer with modulation and demodulation, from StyleGAN2.

    Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
    """
    def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
                 bias=True, resample=None):
        super().__init__()

        self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
        nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
                                 mode='fan_in', nonlinearity='leaky_relu')
        self.style_bias = nn.Parameter(torch.ones(in_chan))  # NOTE: init to 1

        if resample is None:
            K3 = (kernel_size,) * 3
            self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
            self.stride = stride
            self.conv = F.conv3d
        elif resample == 'U':
            K3 = (2,) * 3
            # NOTE: not clear to me why conv_transpose has channels swapped
            self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
            self.stride = 2
            self.conv = F.conv_transpose3d
        elif resample == 'D':
            K3 = (2,) * 3
            self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
            self.stride = 2
            self.conv = F.conv3d
        else:
            raise ValueError('resample type {} not supported'.format(resample))
        self.resample = resample

        nn.init.kaiming_uniform_(
            self.weight, a=math.sqrt(5),
            mode='fan_in',  # effectively 'fan_out' for 'D'
            nonlinearity='leaky_relu',
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_chan))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def forward(self, x, s, eps=1e-8):
        N, Cin, *DHWin = x.shape
        C0, C1, *K3 = self.weight.shape
        if self.resample == 'U':
            Cin, Cout = C0, C1
        else:
            Cout, Cin = C0, C1

        s = F.linear(s, self.style_weight, bias=self.style_bias)

        # modulation
        if self.resample == 'U':
            s = s.reshape(N, Cin, 1, 1, 1, 1)
        else:
            s = s.reshape(N, 1, Cin, 1, 1, 1)
        w = self.weight * s

        # demodulation
        if self.resample == 'U':
            fan_in_dim = (1, 3, 4, 5)
        else:
            fan_in_dim = (2, 3, 4, 5)
        w = w * torch.rsqrt(w.pow(2).sum(dim=fan_in_dim, keepdim=True) + eps)

        # batch the samples into the group dimension of one grouped conv
        w = w.reshape(N * C0, C1, *K3)
        x = x.reshape(1, N * Cin, *DHWin)
        x = self.conv(x, w, stride=self.stride, groups=N)
        _, _, *DHWout = x.shape
        x = x.reshape(N, Cout, *DHWout)

        # add bias after ungrouping; the grouped conv itself would expect
        # N * Cout bias entries, while self.bias only has Cout
        if self.bias is not None:
            x = x + self.bias.reshape(-1, 1, 1, 1)

        return x
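
# Usage sketch (illustration only, parameter values are made up). Each
# sample's style vector modulates the shared conv weights, and demodulation
# rescales them back to ~unit fan-in norm:
#
#   >>> conv = ConvMod3d(style_size=128, in_chan=16, out_chan=32)
#   >>> x = torch.randn(2, 16, 8, 8, 8)  # N, C, D, H, W
#   >>> s = torch.randn(2, 128)          # per-sample style vectors
#   >>> y = conv(x, s)                   # shape (2, 32, 6, 6, 6), no padding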