Add stride to ConvBlock
This commit is contained in:
parent
cc0efd28ec
commit
aeeeb966d8
@ -6,9 +6,16 @@ from .swish import Swish
|
|||||||
|
|
||||||
class ConvBlock(nn.Module):
|
class ConvBlock(nn.Module):
|
||||||
"""Convolution blocks of the form specified by `seq`.
|
"""Convolution blocks of the form specified by `seq`.
|
||||||
|
|
||||||
|
`seq` types:
|
||||||
|
'C': convolution specified by `kernel_size` and `stride`
|
||||||
|
'B': normalization (to be renamed to 'N')
|
||||||
|
'A': activation
|
||||||
|
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
||||||
|
'D': downsampling convolution of kernel size 2 and stride 2
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||||
kernel_size=3, seq='CBA'):
|
kernel_size=3, stride=1, seq='CBA'):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if out_chan is None:
|
if out_chan is None:
|
||||||
@ -19,6 +26,7 @@ class ConvBlock(nn.Module):
|
|||||||
if mid_chan is None:
|
if mid_chan is None:
|
||||||
self.mid_chan = max(in_chan, out_chan)
|
self.mid_chan = max(in_chan, out_chan)
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
self.norm_chan = in_chan
|
self.norm_chan = in_chan
|
||||||
self.idx_conv = 0
|
self.idx_conv = 0
|
||||||
@ -37,7 +45,8 @@ class ConvBlock(nn.Module):
|
|||||||
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
|
return nn.Conv3d(in_chan, out_chan, 2, stride=2)
|
||||||
elif l == 'C':
|
elif l == 'C':
|
||||||
in_chan, out_chan = self._setup_conv()
|
in_chan, out_chan = self._setup_conv()
|
||||||
return nn.Conv3d(in_chan, out_chan, self.kernel_size)
|
return nn.Conv3d(in_chan, out_chan, self.kernel_size,
|
||||||
|
stride=self.stride)
|
||||||
elif l == 'B':
|
elif l == 'B':
|
||||||
return nn.BatchNorm3d(self.norm_chan)
|
return nn.BatchNorm3d(self.norm_chan)
|
||||||
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
#return nn.InstanceNorm3d(self.norm_chan, affine=True, track_running_stats=True)
|
||||||
@ -67,6 +76,8 @@ class ConvBlock(nn.Module):
|
|||||||
class ResBlock(ConvBlock):
|
class ResBlock(ConvBlock):
|
||||||
"""Residual convolution blocks of the form specified by `seq`. Input is added
|
"""Residual convolution blocks of the form specified by `seq`. Input is added
|
||||||
to the residual followed by an optional activation (trailing `'A'` in `seq`).
|
to the residual followed by an optional activation (trailing `'A'` in `seq`).
|
||||||
|
|
||||||
|
See `ConvBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||||
seq='CBACBA'):
|
seq='CBACBA'):
|
||||||
|
Loading…
Reference in New Issue
Block a user