Add kernel_size and stride to ResBlock
This commit is contained in:
parent
fd1cdb0ce7
commit
55b1a72ef4
@ -90,7 +90,7 @@ class ResBlock(ConvBlock):
|
|||||||
See `ConvBlock` for `seq` types.
|
See `ConvBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||||
seq='CBACBA', last_act=None):
|
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||||
if last_act is None:
|
if last_act is None:
|
||||||
last_act = seq[-1] == 'A'
|
last_act = seq[-1] == 'A'
|
||||||
elif last_act and seq[-1] != 'A':
|
elif last_act and seq[-1] != 'A':
|
||||||
@ -103,7 +103,8 @@ class ResBlock(ConvBlock):
|
|||||||
if last_act:
|
if last_act:
|
||||||
seq = seq[:-1]
|
seq = seq[:-1]
|
||||||
|
|
||||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||||
|
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||||
|
|
||||||
if last_act:
|
if last_act:
|
||||||
self.act = nn.LeakyReLU()
|
self.act = nn.LeakyReLU()
|
||||||
|
Loading…
Reference in New Issue
Block a user