Add kernel_size and stride to ResBlock

This commit is contained in:
Yin Li 2021-03-17 14:01:00 -04:00
parent fd1cdb0ce7
commit 55b1a72ef4

View File

@ -90,7 +90,7 @@ class ResBlock(ConvBlock):
See `ConvBlock` for `seq` types. See `ConvBlock` for `seq` types.
""" """
def __init__(self, in_chan, out_chan=None, mid_chan=None, def __init__(self, in_chan, out_chan=None, mid_chan=None,
seq='CBACBA', last_act=None): kernel_size=3, stride=1, seq='CBACBA', last_act=None):
if last_act is None: if last_act is None:
last_act = seq[-1] == 'A' last_act = seq[-1] == 'A'
elif last_act and seq[-1] != 'A': elif last_act and seq[-1] != 'A':
@ -103,7 +103,8 @@ class ResBlock(ConvBlock):
if last_act: if last_act:
seq = seq[:-1] seq = seq[:-1]
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan,
kernel_size=kernel_size, stride=stride, seq=seq)
if last_act: if last_act:
self.act = nn.LeakyReLU() self.act = nn.LeakyReLU()