From 9e039c0407bb165514bc8ead3f16e8a2fe614890 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 9 Dec 2019 20:59:39 -0500 Subject: [PATCH] Add variable skip connection in ResBlock --- map2map/models/conv.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 834d761..6b71dc3 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -59,14 +59,29 @@ class ResBlock(ConvBlock): """Residual convolution blocks of the form specified by `seq`. Input is added to the residual followed by an activation. """ - def __init__(self, channels, seq='CBACB'): - super().__init__(in_channels=channels, out_channels=channels, seq=seq) + def __init__(self, in_channels, out_channels=None, mid_channels=None, + seq='CBACB'): + if 'U' in seq or 'D' in seq: + raise NotImplementedError('upsample and downsample layers ' + 'not supported yet') + + if out_channels is None: + out_channels = in_channels + self.skip = None + else: + self.skip = nn.Conv3d(in_channels, out_channels, 1) + + super().__init__(in_channels, out_channels, mid_channels=mid_channels, + seq=seq) self.act = nn.PReLU() def forward(self, x): y = x + if self.skip is not None: + y = self.skip(y) + x = self.convs(x) y = narrow_like(y, x)