Add swish activation, kernel_size in ConvBlock, and optional trailing activation in ResBlock
This commit is contained in:
parent 341bdbff84
commit 946805c6be
@@ -1,17 +1,24 @@
 import torch
 import torch.nn as nn
 
+from .swish import Swish
+
 
 class ConvBlock(nn.Module):
     """Convolution blocks of the form specified by `seq`.
     """
-    def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'):
+    def __init__(self, in_channels, out_channels=None, mid_channels=None,
+            kernel_size=3, seq='CBA'):
         super().__init__()
 
+        if out_channels is None:
+            out_channels = in_channels
+
         self.in_channels = in_channels
         self.out_channels = out_channels
         if mid_channels is None:
             self.mid_channels = max(in_channels, out_channels)
+        self.kernel_size = kernel_size
 
         self.bn_channels = in_channels
         self.idx_conv = 0
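A quick construction sketch of the widened ConvBlock signature; the import path is an assumption, so adjust it to the actual package layout:

    from map2map.models import ConvBlock  # assumed import path

    # out_channels now defaults to in_channels, and the convolution size is
    # configurable through the new kernel_size argument (default 3, as before).
    same_width = ConvBlock(64, seq='CBA')        # 64 -> 64 channels, 3**3 kernels
    wider = ConvBlock(64, 128, kernel_size=5)    # 64 -> 128 channels, 5**3 kernels

    assert same_width.out_channels == same_width.in_channels == 64
    assert wider.kernel_size == 5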
@@ -30,11 +37,11 @@ class ConvBlock(nn.Module):
             return nn.Conv3d(in_channels, out_channels, 2, stride=2)
         elif l == 'C':
             in_channels, out_channels = self._setup_conv()
-            return nn.Conv3d(in_channels, out_channels, 3)
+            return nn.Conv3d(in_channels, out_channels, self.kernel_size)
         elif l == 'B':
             return nn.BatchNorm3d(self.bn_channels)
         elif l == 'A':
-            return nn.PReLU()
+            return Swish()
         else:
             raise NotImplementedError('layer type {} not supported'.format(l))
 
@@ -56,13 +63,16 @@ class ConvBlock(nn.Module):
 
 
 class ResBlock(ConvBlock):
-    """Residual convolution blocks of the form specified by `seq`. Input is
-    added to the residual followed by an activation (trailing `'A'` in `seq`).
+    """Residual convolution blocks of the form specified by `seq`. Input is added
+    to the residual followed by an optional activation (trailing `'A'` in `seq`).
     """
     def __init__(self, in_channels, out_channels=None, mid_channels=None,
             seq='CBACBA'):
+        super().__init__(in_channels, out_channels=out_channels,
+                mid_channels=mid_channels,
+                seq=seq[:-1] if seq[-1] == 'A' else seq)
+
         if out_channels is None:
-            out_channels = in_channels
             self.skip = None
         else:
             self.skip = nn.Conv3d(in_channels, out_channels, 1)
@@ -71,12 +81,10 @@ class ResBlock(ConvBlock):
             raise NotImplementedError('upsample and downsample layers '
                     'not supported yet')
 
-        assert seq[-1] == 'A', 'block must end with activation'
-
-        super().__init__(in_channels, out_channels, mid_channels=mid_channels,
-                seq=seq[:-1])
-
-        self.act = nn.PReLU()
+        if seq[-1] == 'A':
+            self.act = Swish()
+        else:
+            self.act = None
 
     def forward(self, x):
         y = x
@@ -89,7 +97,10 @@ class ResBlock(ConvBlock):
         y = narrow_like(y, x)
         x += y
 
-        return self.act(x)
+        if self.act is not None:
+            x = self.act(x)
+
+        return x
 
 
 def narrow_like(a, b):
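A similar sketch for the new optional trailing activation in ResBlock; again the import path is an assumption:

    from map2map.models import ResBlock  # assumed import path

    # A trailing 'A' in `seq` now appends a Swish after the skip addition;
    # leaving it off ends the block at the addition (self.act stays None).
    res_with_act = ResBlock(64, seq='CBACBA')
    res_linear = ResBlock(64, seq='CBACB')

    assert res_with_act.act is not None
    assert res_linear.act is None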
map2map/models/swish.py (new file)
@@ -0,0 +1,20 @@
+import torch
+
+
+class SwishFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input):
+        result = input * torch.sigmoid(input)
+        ctx.save_for_backward(input)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_variables
+        sigmoid = torch.sigmoid(input)
+        return grad_output * (sigmoid * (1 + input * (1 - sigmoid)))
+
+
+class Swish(torch.nn.Module):
+    def forward(self, input):
+        return SwishFunction.apply(input)
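A small check of the hand-written Swish backward against PyTorch's numerical gradient, and of the module wrapper against the plain x * sigmoid(x) expression; this assumes the script runs from the repository root so that map2map.models.swish is importable:

    import torch

    from map2map.models.swish import Swish, SwishFunction

    # gradcheck wants double-precision inputs with requires_grad set.
    x = torch.randn(8, dtype=torch.float64, requires_grad=True)

    # The analytic backward, sigmoid * (1 + x * (1 - sigmoid)), should agree
    # with the numerically estimated Jacobian.
    assert torch.autograd.gradcheck(SwishFunction.apply, (x,))

    # The module wrapper matches the composite expression in the forward pass.
    y = Swish()(x)
    assert torch.allclose(y, x * torch.sigmoid(x))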
@@ -9,15 +9,15 @@ class UNet(nn.Module):
         super().__init__()
 
         self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
-        self.down_0l = ConvBlock(64, 64, seq='BADBA')
-        self.conv_1l = ConvBlock(64, 64, seq='CBAC')
-        self.down_1l = ConvBlock(64, 64, seq='BADBA')
+        self.down_0l = ConvBlock(64, seq='BADBA')
+        self.conv_1l = ConvBlock(64, seq='CBAC')
+        self.down_1l = ConvBlock(64, seq='BADBA')
 
-        self.conv_2c = ConvBlock(64, 64, seq='CBAC')
+        self.conv_2c = ConvBlock(64, seq='CBAC')
 
-        self.up_1r = ConvBlock(64, 64, seq='BAUBA')
+        self.up_1r = ConvBlock(64, seq='BAUBA')
         self.conv_1r = ConvBlock(128, 64, seq='CBAC')
-        self.up_0r = ConvBlock(64, 64, seq='BAUBA')
+        self.up_0r = ConvBlock(64, seq='BAUBA')
         self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
 
     def forward(self, x):