Add variable skip connection in ResBlock

This commit is contained in:
Yin Li 2019-12-09 20:59:39 -05:00
parent 139caabe31
commit 9e039c0407

View File

@ -59,14 +59,29 @@ class ResBlock(ConvBlock):
"""Residual convolution blocks of the form specified by `seq`. Input is """Residual convolution blocks of the form specified by `seq`. Input is
added to the residual followed by an activation. added to the residual followed by an activation.
""" """
def __init__(self, channels, seq='CBACB'): def __init__(self, in_channels, out_channels=None, mid_channels=None,
super().__init__(in_channels=channels, out_channels=channels, seq=seq) seq='CBACB'):
if 'U' in seq or 'D' in seq:
raise NotImplementedError('upsample and downsample layers '
'not supported yet')
if out_channels is None:
out_channels = in_channels
self.skip = None
else:
self.skip = nn.Conv3d(in_channels, out_channels, 1)
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
seq=seq)
self.act = nn.PReLU() self.act = nn.PReLU()
def forward(self, x): def forward(self, x):
y = x y = x
if self.skip is not None:
y = self.skip(y)
x = self.convs(x) x = self.convs(x)
y = narrow_like(y, x) y = narrow_like(y, x)