Add variable skip connection in ResBlock
This commit is contained in:
parent
139caabe31
commit
9e039c0407
@ -59,14 +59,29 @@ class ResBlock(ConvBlock):
|
|||||||
"""Residual convolution blocks of the form specified by `seq`. Input is
|
"""Residual convolution blocks of the form specified by `seq`. Input is
|
||||||
added to the residual followed by an activation.
|
added to the residual followed by an activation.
|
||||||
"""
|
"""
|
||||||
def __init__(self, channels, seq='CBACB'):
|
def __init__(self, in_channels, out_channels=None, mid_channels=None,
|
||||||
super().__init__(in_channels=channels, out_channels=channels, seq=seq)
|
seq='CBACB'):
|
||||||
|
if 'U' in seq or 'D' in seq:
|
||||||
|
raise NotImplementedError('upsample and downsample layers '
|
||||||
|
'not supported yet')
|
||||||
|
|
||||||
|
if out_channels is None:
|
||||||
|
out_channels = in_channels
|
||||||
|
self.skip = None
|
||||||
|
else:
|
||||||
|
self.skip = nn.Conv3d(in_channels, out_channels, 1)
|
||||||
|
|
||||||
|
super().__init__(in_channels, out_channels, mid_channels=mid_channels,
|
||||||
|
seq=seq)
|
||||||
|
|
||||||
self.act = nn.PReLU()
|
self.act = nn.PReLU()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x
|
y = x
|
||||||
|
|
||||||
|
if self.skip is not None:
|
||||||
|
y = self.skip(y)
|
||||||
|
|
||||||
x = self.convs(x)
|
x = self.convs(x)
|
||||||
|
|
||||||
y = narrow_like(y, x)
|
y = narrow_like(y, x)
|
||||||
|
Loading…
Reference in New Issue
Block a user