Add variable skip connection in ResBlock
This commit is contained in:
parent
139caabe31
commit
9e039c0407
1 changed files with 17 additions and 2 deletions
|
@ -59,14 +59,29 @@ class ResBlock(ConvBlock):
|
|||
"""Residual convolution blocks of the form specified by `seq`. Input is
|
||||
added to the residual followed by an activation.
|
||||
"""
|
||||
def __init__(self, channels, seq='CBACB'):
|
||||
super().__init__(in_channels=channels, out_channels=channels, seq=seq)
|
||||
def __init__(self, in_channels, out_channels=None, mid_channels=None,
|
||||
seq='CBACB'):
|
||||
if 'U' in seq or 'D' in seq:
|
||||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.skip = None
|
||||
else:
|
||||
self.skip = nn.Conv3d(in_channels, out_channels, 1)
|
||||
|
||||
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
|
||||
seq=seq)
|
||||
|
||||
self.act = nn.PReLU()
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
|
||||
if self.skip is not None:
|
||||
y = self.skip(y)
|
||||
|
||||
x = self.convs(x)
|
||||
|
||||
y = narrow_like(y, x)
|
||||
|
|
Loading…
Reference in a new issue