Add enhanced model

Add ResBlock using ConvBlock;
Use ResBlock in UNet;
Use PReLU instead of LeakyReLU;
Add more channels in lower levels in UNet;
Add more blocks in each level in UNet;
This commit is contained in:
Yin Li 2019-11-30 21:49:10 -05:00
parent 88bfd11594
commit 9d4b5daae3
2 changed files with 47 additions and 12 deletions

View File

@ -5,7 +5,7 @@ import torch.nn as nn
class ConvBlock(nn.Module):
"""Convolution blocks of the form specified by `seq`.
"""
def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBAC'):
def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'):
super().__init__()
self.in_channels = in_channels
@ -34,7 +34,7 @@ class ConvBlock(nn.Module):
elif l == 'B':
return nn.BatchNorm3d(self.bn_channels)
elif l == 'A':
return nn.LeakyReLU(inplace=True)
return nn.PReLU()
else:
raise NotImplementedError('layer type {} not supported'.format(l))
@ -55,6 +55,26 @@ class ConvBlock(nn.Module):
return self.convs(x)
class ResBlock(ConvBlock):
"""Residual convolution blocks of the form specified by `seq`. Input is
added to the residual followed by an activation.
"""
def __init__(self, channels, seq='CBACB'):
super().__init__(in_channels=channels, out_channels=channels, seq=seq)
self.act = nn.PReLU()
def forward(self, x):
y = x
x = self.convs(x)
y = narrow_like(y, x)
x += y
return self.act(x)
def narrow_like(a, b):
"""Narrow a to be like b.

View File

@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, narrow_like
from .conv import ConvBlock, ResBlock, narrow_like
class DownBlock(ConvBlock):
@ -17,17 +17,32 @@ class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
self.down_0l = DownBlock(64, 64)
self.conv_1l = ConvBlock(64, 64)
self.down_1l = DownBlock(64, 64)
self.conv_0l = nn.Sequential(
ConvBlock(in_channels, 64, seq='CA'),
ResBlock(64, seq='CBACBACB'),
)
self.down_0l = ConvBlock(64, 128, seq='DBA')
self.conv_1l = nn.Sequential(
ResBlock(128, seq='CBACB'),
ResBlock(128, seq='CBACB'),
)
self.down_1l = ConvBlock(128, 256, seq='DBA')
self.conv_2c = ConvBlock(64, 64)
self.conv_2c = nn.Sequential(
ResBlock(256, seq='CBACB'),
ResBlock(256, seq='CBACB'),
)
self.up_1r = UpBlock(64, 64)
self.conv_1r = ConvBlock(128, 64)
self.up_0r = UpBlock(64, 64)
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
self.up_1r = ConvBlock(256, 128, seq='UBA')
self.conv_1r = nn.Sequential(
ResBlock(256, seq='CBACB'),
ResBlock(256, seq='CBACB'),
)
self.up_0r = ConvBlock(256, 64, seq='UBA')
self.conv_0r = nn.Sequential(
ResBlock(128, seq='CBACBAC'),
ConvBlock(128, out_channels, seq='C')
)
def forward(self, x):
y0 = self.conv_0l(x)