171 lines
4.5 KiB
Python
171 lines
4.5 KiB
Python
import warnings
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .narrow import narrow_like
|
|
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
|
|
|
|
|
class ConvStyledBlock(nn.Module):
    """Convolution blocks of the form specified by `seq`.

    `seq` types:
    'C': convolution specified by `kernel_size` and `stride`
    'B': normalization (to be renamed to 'N')
    'A': activation
    'U': upsampling transposed convolution of kernel size 2 and stride 2
    'D': downsampling convolution of kernel size 2 and stride 2

    Parameters
    ----------
    style_size : size of the style vector `s` passed alongside the input
        to every styled layer in `forward`.
    in_chan : number of input channels.
    out_chan : number of output channels; defaults to `in_chan`.
    mid_chan : number of channels of the intermediate convolutions;
        defaults to `max(in_chan, out_chan)`.
    kernel_size, stride : parameters of the 'C' convolutions; the 'U' and
        'D' layers always use kernel size 2 and stride 2.
    seq : string of layer type characters, instantiated left to right.
    """

    def __init__(
        self,
        style_size,
        in_chan,
        out_chan=None,
        mid_chan=None,
        kernel_size=3,
        stride=1,
        seq="CBA",
    ):
        super().__init__()

        if out_chan is None:
            out_chan = in_chan

        self.style_size = style_size
        self.in_chan = in_chan
        self.out_chan = out_chan
        # BUGFIX: `self.mid_chan` was previously assigned only when
        # `mid_chan is None`, so constructing the block with an explicit
        # `mid_chan` raised AttributeError later in `_setup_conv`.
        if mid_chan is None:
            mid_chan = max(in_chan, out_chan)
        self.mid_chan = mid_chan
        self.kernel_size = kernel_size
        self.stride = stride

        # running state consumed by `_setup_conv` while building layers:
        # `norm_chan` tracks the channel count the next 'B' layer must
        # normalize; `idx_conv`/`num_conv` identify the first and last
        # conv-like layer so their channel counts can differ.
        self.norm_chan = in_chan
        self.idx_conv = 0
        self.num_conv = sum(seq.count(l) for l in "UDC")

        self.convs = nn.ModuleList([self._get_layer(l) for l in seq])

    def _get_layer(self, l):
        """Instantiate the layer for sequence character `l`.

        Raises
        ------
        ValueError : if `l` is not one of 'U', 'D', 'C', 'B', 'A'.
        """
        if l == "U":
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(
                self.style_size, in_chan, out_chan, 2, stride=2, resample="U"
            )
        elif l == "D":
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(
                self.style_size, in_chan, out_chan, 2, stride=2, resample="D"
            )
        elif l == "C":
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(
                self.style_size, in_chan, out_chan, self.kernel_size, stride=self.stride
            )
        elif l == "B":
            return BatchNormStyled3d(self.norm_chan)
        elif l == "A":
            return LeakyReLUStyled()
        else:
            raise ValueError("layer type {} not supported".format(l))

    def _setup_conv(self):
        """Return `(in_chan, out_chan)` for the next conv-like layer.

        The first conv takes `self.in_chan`, the last emits
        `self.out_chan`, and all others use `self.mid_chan`; also updates
        `self.norm_chan` so a following 'B' layer matches.
        """
        self.idx_conv += 1

        in_chan = out_chan = self.mid_chan
        if self.idx_conv == 1:
            in_chan = self.in_chan
        if self.idx_conv == self.num_conv:
            out_chan = self.out_chan

        self.norm_chan = out_chan

        return in_chan, out_chan

    def forward(self, x, s):
        """Apply every layer in sequence; `s` is the style vector."""
        for l in self.convs:
            x = l(x, s)
        return x
|
|
|
|
|
|
class ResStyledBlock(ConvStyledBlock):
    """Residual convolution blocks of the form specified by `seq`.

    Input, via a skip connection, is added to the residual followed by an
    optional activation.

    The skip connection is identity if `out_chan` is omitted, otherwise it uses
    a size 1 "convolution", i.e. one can trigger the latter by setting
    `out_chan` even if it equals `in_chan`.

    A trailing `'A'` in seq can either operate before or after the addition,
    depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`

    See `ConvStyledBlock` for `seq` types.
    """

    def __init__(
        self,
        style_size,
        in_chan,
        out_chan=None,
        mid_chan=None,
        kernel_size=3,
        stride=1,
        seq="CBACBA",
        last_act=None,
    ):
        has_trailing_act = seq[-1] == "A"

        # Resolve `last_act`: default to whether seq ends in an activation;
        # explicitly requesting it without one is ignored with a warning.
        if last_act is None:
            last_act = has_trailing_act
        elif last_act and not has_trailing_act:
            warnings.warn(
                "Disabling last_act without trailing activation in seq",
                RuntimeWarning,
            )
            last_act = False

        # Move the trailing activation out of the conv sequence so it can
        # run after the skip addition instead of before it.
        if last_act:
            seq = seq[:-1]

        super().__init__(
            style_size,
            in_chan,
            out_chan=out_chan,
            mid_chan=mid_chan,
            kernel_size=kernel_size,
            stride=stride,
            seq=seq,
        )

        self.act = LeakyReLUStyled() if last_act else None

        # Identity skip when `out_chan` is omitted; otherwise a 1x1x1
        # styled convolution to match channel counts.
        self.skip = (
            None
            if out_chan is None
            else ConvStyled3d(style_size, in_chan, out_chan, 1)
        )

        if "U" in seq or "D" in seq:
            raise NotImplementedError(
                "upsample and downsample layers " "not supported yet"
            )

    def forward(self, x, s):
        """Residual forward pass; `s` is the style vector."""
        identity = x if self.skip is None else self.skip(x, s)

        for layer in self.convs:
            x = layer(x, s)

        # Crop the skip branch to the (possibly smaller) residual output
        # before the in-place addition.
        identity = narrow_like(identity, x)
        x += identity

        if self.act is not None:
            x = self.act(x, s)

        return x
|