diff --git a/map2map/models/conv.py b/map2map/models/conv.py index 4650992..38335b6 100644 --- a/map2map/models/conv.py +++ b/map2map/models/conv.py @@ -1,17 +1,24 @@ import torch import torch.nn as nn +from .swish import Swish + class ConvBlock(nn.Module): """Convolution blocks of the form specified by `seq`. """ - def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'): + def __init__(self, in_channels, out_channels=None, mid_channels=None, + kernel_size=3, seq='CBA'): super().__init__() + if out_channels is None: + out_channels = in_channels + self.in_channels = in_channels self.out_channels = out_channels if mid_channels is None: self.mid_channels = max(in_channels, out_channels) + self.kernel_size = kernel_size self.bn_channels = in_channels self.idx_conv = 0 @@ -30,11 +37,11 @@ class ConvBlock(nn.Module): return nn.Conv3d(in_channels, out_channels, 2, stride=2) elif l == 'C': in_channels, out_channels = self._setup_conv() - return nn.Conv3d(in_channels, out_channels, 3) + return nn.Conv3d(in_channels, out_channels, self.kernel_size) elif l == 'B': return nn.BatchNorm3d(self.bn_channels) elif l == 'A': - return nn.PReLU() + return Swish() else: raise NotImplementedError('layer type {} not supported'.format(l)) @@ -56,13 +63,16 @@ class ConvBlock(nn.Module): class ResBlock(ConvBlock): - """Residual convolution blocks of the form specified by `seq`. Input is - added to the residual followed by an activation (trailing `'A'` in `seq`). + """Residual convolution blocks of the form specified by `seq`. Input is added + to the residual followed by an optional activation (trailing `'A'` in `seq`). """ def __init__(self, in_channels, out_channels=None, mid_channels=None, seq='CBACBA'): + super().__init__(in_channels, out_channels=out_channels, + mid_channels=mid_channels, + seq=seq[:-1] if seq[-1] == 'A' else seq) + if out_channels is None: - out_channels = in_channels self.skip = None else: self.skip = nn.Conv3d(in_channels, out_channels, 1) @@ -71,12 +81,10 @@ class ResBlock(ConvBlock): raise NotImplementedError('upsample and downsample layers ' 'not supported yet') - assert seq[-1] == 'A', 'block must end with activation' - - super().__init__(in_channels, out_channels, mid_channels=mid_channels, - seq=seq[:-1]) - - self.act = nn.PReLU() + if seq[-1] == 'A': + self.act = Swish() + else: + self.act = None def forward(self, x): y = x @@ -89,7 +97,10 @@ class ResBlock(ConvBlock): y = narrow_like(y, x) x += y - return self.act(x) + if self.act is not None: + x = self.act(x) + + return x def narrow_like(a, b): diff --git a/map2map/models/swish.py b/map2map/models/swish.py new file mode 100644 index 0000000..c9eda0d --- /dev/null +++ b/map2map/models/swish.py @@ -0,0 +1,20 @@ +import torch + + +class SwishFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input): + result = input * torch.sigmoid(input) + ctx.save_for_backward(input) + return result + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_variables + sigmoid = torch.sigmoid(input) + return grad_output * (sigmoid * (1 + input * (1 - sigmoid))) + + +class Swish(torch.nn.Module): + def forward(self, input): + return SwishFunction.apply(input) diff --git a/map2map/models/unet.py b/map2map/models/unet.py index abd1054..5ae9c6e 100644 --- a/map2map/models/unet.py +++ b/map2map/models/unet.py @@ -9,15 +9,15 @@ class UNet(nn.Module): super().__init__() self.conv_0l = ConvBlock(in_channels, 64, seq='CAC') - self.down_0l = ConvBlock(64, 64, seq='BADBA') - self.conv_1l = ConvBlock(64, 64, seq='CBAC') - self.down_1l = ConvBlock(64, 64, seq='BADBA') + self.down_0l = ConvBlock(64, seq='BADBA') + self.conv_1l = ConvBlock(64, seq='CBAC') + self.down_1l = ConvBlock(64, seq='BADBA') - self.conv_2c = ConvBlock(64, 64, seq='CBAC') + self.conv_2c = ConvBlock(64, seq='CBAC') - self.up_1r = ConvBlock(64, 64, seq='BAUBA') + self.up_1r = ConvBlock(64, seq='BAUBA') self.conv_1r = ConvBlock(128, 64, seq='CBAC') - self.up_0r = ConvBlock(64, 64, seq='BAUBA') + self.up_0r = ConvBlock(64, seq='BAUBA') self.conv_0r = ConvBlock(128, out_channels, seq='CAC') def forward(self, x):