diff --git a/map2map/models/styled_conv.py b/map2map/models/styled_conv.py index 69602b4..2043537 100644 --- a/map2map/models/styled_conv.py +++ b/map2map/models/styled_conv.py @@ -3,10 +3,9 @@ import torch import torch.nn as nn from .narrow import narrow_like -from .swish import Swish - from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled + class ConvStyledBlock(nn.Module): """Convolution blocks of the form specified by `seq`. @@ -58,7 +57,7 @@ class ConvStyledBlock(nn.Module): elif l == 'A': return LeakyReLUStyled() else: - raise NotImplementedError('layer type {} not supported'.format(l)) + raise ValueError('layer type {} not supported'.format(l)) def _setup_conv(self): self.idx_conv += 1 @@ -94,7 +93,7 @@ class ResStyledBlock(ConvStyledBlock): See `ConvStyledBlock` for `seq` types. """ def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None, - seq='CBACBA', last_act=None): + kernel_size=3, stride=1, seq='CBACBA', last_act=None): if last_act is None: last_act = seq[-1] == 'A' elif last_act and seq[-1] != 'A': @@ -107,7 +106,8 @@ class ResStyledBlock(ConvStyledBlock): if last_act: seq = seq[:-1] - super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) + super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, + kernel_size=kernel_size, stride=stride, seq=seq) if last_act: self.act = LeakyReLUStyled() diff --git a/map2map/models/styled_vnet.py b/map2map/models/styled_vnet.py index d9025c5..151aa59 100644 --- a/map2map/models/styled_vnet.py +++ b/map2map/models/styled_vnet.py @@ -27,7 +27,10 @@ class StyledVNet(nn.Module): self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') - self.bypass = in_chan == out_chan + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass def forward(self, x, s): if self.bypass: