diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 523d0ff..08d1f6f 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -90,7 +90,7 @@ class ResBlock(ConvBlock): See `ConvBlock` for `seq` types. """ def __init__(self, in_chan, out_chan=None, mid_chan=None, - seq='CBACBA', last_act=None): + kernel_size=3, stride=1, seq='CBACBA', last_act=None): if last_act is None: last_act = seq[-1] == 'A' elif last_act and seq[-1] != 'A': @@ -103,7 +103,8 @@ class ResBlock(ConvBlock): if last_act: seq = seq[:-1] - super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) + super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, + kernel_size=kernel_size, stride=stride, seq=seq) if last_act: self.act = nn.LeakyReLU()