From aeeeb966d87d7d8a81c40371426e0518da345b49 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Tue, 11 Feb 2020 17:20:42 -0500 Subject: [PATCH] Add stride to ConvBlock --- map2map/models/conv.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/map2map/models/conv.py b/map2map/models/conv.py index b30976c..e4891e2 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -6,9 +6,16 @@ from .swish import Swish class ConvBlock(nn.Module): """Convolution blocks of the form specified by `seq`. + + `seq` types: + 'C': convolution specified by `kernel_size` and `stride` + 'B': normalization (to be renamed to 'N') + 'A': activation + 'U': upsampling transposed convolution of kernel size 2 and stride 2 + 'D': downsampling convolution of kernel size 2 and stride 2 """ def __init__(self, in_chan, out_chan=None, mid_chan=None, - kernel_size=3, seq='CBA'): + kernel_size=3, stride=1, seq='CBA'): super().__init__() if out_chan is None: @@ -19,6 +26,7 @@ class ConvBlock(nn.Module): if mid_chan is None: self.mid_chan = max(in_chan, out_chan) self.kernel_size = kernel_size + self.stride = stride self.norm_chan = in_chan self.idx_conv = 0 @@ -37,7 +45,8 @@ class ConvBlock(nn.Module): return nn.Conv3d(in_chan, out_chan, 2, stride=2) elif l == 'C': in_chan, out_chan = self._setup_conv() - return nn.Conv3d(in_chan, out_chan, self.kernel_size) + return nn.Conv3d(in_chan, out_chan, self.kernel_size, + stride=self.stride) elif l == 'B': return nn.BatchNorm3d(self.norm_chan) #return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True) @@ -67,6 +76,8 @@ class ConvBlock(nn.Module): class ResBlock(ConvBlock): """Residual convolution blocks of the form specified by `seq`. Input is added to the residual followed by an optional activation (trailing `'A'` in `seq`). + + See `ConvBlock` for `seq` types. """ def __init__(self, in_chan, out_chan=None, mid_chan=None, seq='CBACBA'):