Add enhanced model
Add ResBlock using ConvBlock; Use ResBlock in UNet; Use PReLU instead of LeakyReLU; Add more channels in lower levels in UNet; Add more blocks in each level in UNet;
This commit is contained in:
parent
88bfd11594
commit
9d4b5daae3
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
class ConvBlock(nn.Module):
|
||||
"""Convolution blocks of the form specified by `seq`.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBAC'):
|
||||
def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
@ -34,7 +34,7 @@ class ConvBlock(nn.Module):
|
||||
elif l == 'B':
|
||||
return nn.BatchNorm3d(self.bn_channels)
|
||||
elif l == 'A':
|
||||
return nn.LeakyReLU(inplace=True)
|
||||
return nn.PReLU()
|
||||
else:
|
||||
raise NotImplementedError('layer type {} not supported'.format(l))
|
||||
|
||||
@ -55,6 +55,26 @@ class ConvBlock(nn.Module):
|
||||
return self.convs(x)
|
||||
|
||||
|
||||
class ResBlock(ConvBlock):
|
||||
"""Residual convolution blocks of the form specified by `seq`. Input is
|
||||
added to the residual followed by an activation.
|
||||
"""
|
||||
def __init__(self, channels, seq='CBACB'):
|
||||
super().__init__(in_channels=channels, out_channels=channels, seq=seq)
|
||||
|
||||
self.act = nn.PReLU()
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
|
||||
x = self.convs(x)
|
||||
|
||||
y = narrow_like(y, x)
|
||||
x += y
|
||||
|
||||
return self.act(x)
|
||||
|
||||
|
||||
def narrow_like(a, b):
|
||||
"""Narrow a to be like b.
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock, narrow_like
|
||||
from .conv import ConvBlock, ResBlock, narrow_like
|
||||
|
||||
|
||||
class DownBlock(ConvBlock):
|
||||
@ -17,17 +17,32 @@ class UNet(nn.Module):
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super().__init__()
|
||||
|
||||
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
|
||||
self.down_0l = DownBlock(64, 64)
|
||||
self.conv_1l = ConvBlock(64, 64)
|
||||
self.down_1l = DownBlock(64, 64)
|
||||
self.conv_0l = nn.Sequential(
|
||||
ConvBlock(in_channels, 64, seq='CA'),
|
||||
ResBlock(64, seq='CBACBACB'),
|
||||
)
|
||||
self.down_0l = ConvBlock(64, 128, seq='DBA')
|
||||
self.conv_1l = nn.Sequential(
|
||||
ResBlock(128, seq='CBACB'),
|
||||
ResBlock(128, seq='CBACB'),
|
||||
)
|
||||
self.down_1l = ConvBlock(128, 256, seq='DBA')
|
||||
|
||||
self.conv_2c = ConvBlock(64, 64)
|
||||
self.conv_2c = nn.Sequential(
|
||||
ResBlock(256, seq='CBACB'),
|
||||
ResBlock(256, seq='CBACB'),
|
||||
)
|
||||
|
||||
self.up_1r = UpBlock(64, 64)
|
||||
self.conv_1r = ConvBlock(128, 64)
|
||||
self.up_0r = UpBlock(64, 64)
|
||||
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
|
||||
self.up_1r = ConvBlock(256, 128, seq='UBA')
|
||||
self.conv_1r = nn.Sequential(
|
||||
ResBlock(256, seq='CBACB'),
|
||||
ResBlock(256, seq='CBACB'),
|
||||
)
|
||||
self.up_0r = ConvBlock(256, 64, seq='UBA')
|
||||
self.conv_0r = nn.Sequential(
|
||||
ResBlock(128, seq='CBACBAC'),
|
||||
ConvBlock(128, out_channels, seq='C')
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y0 = self.conv_0l(x)
|
||||
|
Loading…
Reference in New Issue
Block a user