From 55b1a72ef4acc3ce9a4d970d1323f1cdc3844165 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 17 Mar 2021 14:01:00 -0400 Subject: [PATCH] Add kernel_size and stride to ResBlock --- map2map/models/conv.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 523d0ff..08d1f6f 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -90,7 +90,7 @@ class ResBlock(ConvBlock): See `ConvBlock` for `seq` types. """ def __init__(self, in_chan, out_chan=None, mid_chan=None, - seq='CBACBA', last_act=None): + kernel_size=3, stride=1, seq='CBACBA', last_act=None): if last_act is None: last_act = seq[-1] == 'A' elif last_act and seq[-1] != 'A': @@ -103,7 +103,8 @@ class ResBlock(ConvBlock): if last_act: seq = seq[:-1] - super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) + super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, + kernel_size=kernel_size, stride=stride, seq=seq) if last_act: self.act = nn.LeakyReLU()