Add swish activation, kernel_size in ConvBlock, and optional trailing activation in ResBlock

Yin Li 2019-12-12 15:30:38 -05:00
parent 341bdbff84
commit 946805c6be
3 changed files with 50 additions and 19 deletions


@@ -1,17 +1,24 @@
 import torch
 import torch.nn as nn
 
+from .swish import Swish
+
 
 class ConvBlock(nn.Module):
     """Convolution blocks of the form specified by `seq`.
     """
-    def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'):
+    def __init__(self, in_channels, out_channels=None, mid_channels=None,
+            kernel_size=3, seq='CBA'):
         super().__init__()
 
+        if out_channels is None:
+            out_channels = in_channels
+
         self.in_channels = in_channels
         self.out_channels = out_channels
         if mid_channels is None:
             self.mid_channels = max(in_channels, out_channels)
+        self.kernel_size = kernel_size
 
         self.bn_channels = in_channels
         self.idx_conv = 0
@@ -30,11 +37,11 @@ class ConvBlock(nn.Module):
             return nn.Conv3d(in_channels, out_channels, 2, stride=2)
         elif l == 'C':
             in_channels, out_channels = self._setup_conv()
-            return nn.Conv3d(in_channels, out_channels, 3)
+            return nn.Conv3d(in_channels, out_channels, self.kernel_size)
         elif l == 'B':
             return nn.BatchNorm3d(self.bn_channels)
         elif l == 'A':
-            return nn.PReLU()
+            return Swish()
         else:
             raise NotImplementedError('layer type {} not supported'.format(l))
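(Aside, not part of the diff: after this change a `seq` string such as 'CBA' expands roughly as sketched below. The channel counts are illustrative only; the import path follows the swish.py file added in this commit.)

import torch.nn as nn
from map2map.models.swish import Swish  # module added in this commit

# Rough expansion of seq='CBA' after this change (channel counts illustrative):
example = nn.Sequential(
    nn.Conv3d(64, 64, 3),  # 'C': convolution, kernel size now taken from self.kernel_size
    nn.BatchNorm3d(64),    # 'B': batch normalization
    Swish(),               # 'A': Swish activation, replacing nn.PReLU()
)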
@@ -56,13 +63,16 @@ class ConvBlock(nn.Module):
 
 class ResBlock(ConvBlock):
-    """Residual convolution blocks of the form specified by `seq`. Input is
-    added to the residual followed by an activation (trailing `'A'` in `seq`).
+    """Residual convolution blocks of the form specified by `seq`. Input is added
+    to the residual followed by an optional activation (trailing `'A'` in `seq`).
     """
     def __init__(self, in_channels, out_channels=None, mid_channels=None,
             seq='CBACBA'):
+        super().__init__(in_channels, out_channels=out_channels,
+                mid_channels=mid_channels,
+                seq=seq[:-1] if seq[-1] == 'A' else seq)
+
         if out_channels is None:
-            out_channels = in_channels
             self.skip = None
         else:
             self.skip = nn.Conv3d(in_channels, out_channels, 1)
@@ -71,12 +81,10 @@ class ResBlock(ConvBlock):
             raise NotImplementedError('upsample and downsample layers '
                     'not supported yet')
 
-        assert seq[-1] == 'A', 'block must end with activation'
-
-        super().__init__(in_channels, out_channels, mid_channels=mid_channels,
-                seq=seq[:-1])
-
-        self.act = nn.PReLU()
+        if seq[-1] == 'A':
+            self.act = Swish()
+        else:
+            self.act = None
 
     def forward(self, x):
         y = x
@@ -89,7 +97,10 @@ class ResBlock(ConvBlock):
         y = narrow_like(y, x)
         x += y
 
-        return self.act(x)
+        if self.act is not None:
+            x = self.act(x)
+
+        return x
 
 
 def narrow_like(a, b):

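(Aside, not part of the diff: a minimal usage sketch of the updated blocks. The import path and tensor shapes are assumptions, not taken from this commit.)

import torch
from map2map.models.conv import ConvBlock, ResBlock  # path assumed, not shown in this diff

same = ConvBlock(64, seq='CBA')                      # out_channels now defaults to in_channels
wide = ConvBlock(64, 128, kernel_size=5, seq='CBA')  # 64 -> 128 channels, 5x5x5 convolution
res = ResBlock(64, seq='CBACBA')                     # trailing 'A': Swish applied after the skip add
raw = ResBlock(64, seq='CBACB')                      # no trailing 'A': the sum is returned unactivated

x = torch.randn(2, 64, 16, 16, 16)
y = same(x)  # unpadded 3x3x3 conv: spatial size 16 -> 14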
map2map/models/swish.py (new file, 20 lines)

@@ -0,0 +1,20 @@
+import torch
+
+
+class SwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        result = input * torch.sigmoid(input)
+        ctx.save_for_backward(input)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_variables
+        sigmoid = torch.sigmoid(input)
+        return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))
+
+
+class Swish(torch.nn.Module):
+    def forward(self, input):
+        return SwishFunction.apply(input)

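The hand-written backward implements the Swish derivative d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). A quick way to sanity-check it against PyTorch's numerical gradients is sketched below (not part of the commit); the sketch uses ctx.saved_tensors, the current name for what the committed code reaches through the older ctx.saved_variables alias.

import torch

class SwishFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input * torch.sigmoid(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        sigmoid = torch.sigmoid(input)
        # analytic derivative of x * sigmoid(x)
        return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))

# gradcheck compares the analytic backward against finite differences (needs float64)
x = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(SwishFunction.apply, (x,))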

@@ -9,15 +9,15 @@ class UNet(nn.Module):
         super().__init__()
 
         self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
-        self.down_0l = ConvBlock(64, 64, seq='BADBA')
-        self.conv_1l = ConvBlock(64, 64, seq='CBAC')
-        self.down_1l = ConvBlock(64, 64, seq='BADBA')
+        self.down_0l = ConvBlock(64, seq='BADBA')
+        self.conv_1l = ConvBlock(64, seq='CBAC')
+        self.down_1l = ConvBlock(64, seq='BADBA')
 
-        self.conv_2c = ConvBlock(64, 64, seq='CBAC')
+        self.conv_2c = ConvBlock(64, seq='CBAC')
 
-        self.up_1r = ConvBlock(64, 64, seq='BAUBA')
+        self.up_1r = ConvBlock(64, seq='BAUBA')
         self.conv_1r = ConvBlock(128, 64, seq='CBAC')
 
-        self.up_0r = ConvBlock(64, 64, seq='BAUBA')
+        self.up_0r = ConvBlock(64, seq='BAUBA')
         self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
 
     def forward(self, x):
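(Aside: with the new out_channels default, the single-argument ConvBlock calls above are shorthand for the previous explicit form, e.g.:)

from map2map.models.conv import ConvBlock  # path assumed, as above

down_new = ConvBlock(64, seq='BADBA')      # out_channels defaults to in_channels
down_old = ConvBlock(64, 64, seq='BADBA')  # the previous, explicit spelling
assert down_new.out_channels == down_old.out_channels == 64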