Add enhanced model
Add ResBlock using ConvBlock
Use ResBlock in UNet
Use PReLU instead of LeakyReLU
Add more channels in lower levels in UNet
Add more blocks in each level in UNet
parent 88bfd11594
commit 9d4b5daae3
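Before the diff: the blocks below are all driven by a `seq` string in which each letter names a layer ('C' convolution, 'B' batch norm, 'A' activation, now PReLU). The seq='DBA' and seq='UBA' uses further down imply 'D' and 'U' codes for strided down- and up-sampling convolutions, though the hunks shown here don't include their dispatch. A minimal sketch of the idea, not the repo's exact code; the kernel sizes and the unpadded convolutions are assumptions:

import torch.nn as nn

def build_seq(seq, in_chan, out_chan):
    """Build an nn.Sequential from a layer-spec string (illustrative only)."""
    layers, chan = [], in_chan
    for l in seq:
        if l == 'C':
            layers.append(nn.Conv3d(chan, out_chan, 3))  # unpadded 3^3 conv, assumed
            chan = out_chan
        elif l == 'D':
            layers.append(nn.Conv3d(chan, out_chan, 2, stride=2))  # assumed downsampler
            chan = out_chan
        elif l == 'U':
            layers.append(nn.ConvTranspose3d(chan, out_chan, 2, stride=2))  # assumed upsampler
            chan = out_chan
        elif l == 'B':
            layers.append(nn.BatchNorm3d(chan))
        elif l == 'A':
            layers.append(nn.PReLU())
        else:
            raise NotImplementedError('layer type {} not supported'.format(l))
    return nn.Sequential(*layers)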
@@ -5,7 +5,7 @@ import torch.nn as nn
 class ConvBlock(nn.Module):
     """Convolution blocks of the form specified by `seq`.
     """
-    def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBAC'):
+    def __init__(self, in_channels, out_channels, mid_channels=None, seq='CBA'):
         super().__init__()
 
         self.in_channels = in_channels
@@ -34,7 +34,7 @@ class ConvBlock(nn.Module):
         elif l == 'B':
             return nn.BatchNorm3d(self.bn_channels)
         elif l == 'A':
-            return nn.LeakyReLU(inplace=True)
+            return nn.PReLU()
         else:
             raise NotImplementedError('layer type {} not supported'.format(l))
 
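On the LeakyReLU → PReLU swap above: PReLU learns its negative slope rather than fixing it, at the cost of a tiny learnable parameter. A quick comparison using plain PyTorch:

import torch.nn as nn

leaky = nn.LeakyReLU(inplace=True)  # fixed negative slope, 0.01 by default
prelu = nn.PReLU()                  # learnable negative slope, initialized to 0.25
print(sum(p.numel() for p in prelu.parameters()))  # -> 1, a single shared slope

One practical consequence: that slope is a model parameter, so it is usually excluded from weight decay.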
@@ -55,6 +55,26 @@ class ConvBlock(nn.Module):
         return self.convs(x)
 
 
+class ResBlock(ConvBlock):
+    """Residual convolution blocks of the form specified by `seq`. Input is
+    added to the residual followed by an activation.
+    """
+    def __init__(self, channels, seq='CBACB'):
+        super().__init__(in_channels=channels, out_channels=channels, seq=seq)
+
+        self.act = nn.PReLU()
+
+    def forward(self, x):
+        y = x
+
+        x = self.convs(x)
+
+        y = narrow_like(y, x)
+        x += y
+
+        return self.act(x)
+
+
 def narrow_like(a, b):
     """Narrow a to be like b.
 
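The residual add in ResBlock.forward only works after cropping: with unpadded convolutions the residual branch comes out spatially smaller than its input, so the skip tensor is narrowed to match before the sum. narrow_like's body is truncated above; the sketch below uses an assumed center-crop implementation to show the shapes:

import torch

def narrow_like(a, b):
    """Center-crop a to b's spatial shape (assumed implementation)."""
    for d in range(2, a.dim()):  # skip batch and channel dims
        start = (a.shape[d] - b.shape[d]) // 2
        a = a.narrow(d, start, b.shape[d])
    return a

x = torch.randn(1, 8, 16, 16, 16)
y = torch.randn(1, 8, 12, 12, 12)  # e.g. after two unpadded 3^3 convs (seq='CBACB')
print(narrow_like(x, y).shape)     # torch.Size([1, 8, 12, 12, 12])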
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from .conv import ConvBlock, narrow_like
+from .conv import ConvBlock, ResBlock, narrow_like
 
 
 class DownBlock(ConvBlock):
@@ -17,17 +17,32 @@ class UNet(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
 
-        self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
-        self.down_0l = DownBlock(64, 64)
-        self.conv_1l = ConvBlock(64, 64)
-        self.down_1l = DownBlock(64, 64)
+        self.conv_0l = nn.Sequential(
+            ConvBlock(in_channels, 64, seq='CA'),
+            ResBlock(64, seq='CBACBACB'),
+        )
+        self.down_0l = ConvBlock(64, 128, seq='DBA')
+        self.conv_1l = nn.Sequential(
+            ResBlock(128, seq='CBACB'),
+            ResBlock(128, seq='CBACB'),
+        )
+        self.down_1l = ConvBlock(128, 256, seq='DBA')
 
-        self.conv_2c = ConvBlock(64, 64)
+        self.conv_2c = nn.Sequential(
+            ResBlock(256, seq='CBACB'),
+            ResBlock(256, seq='CBACB'),
+        )
 
-        self.up_1r = UpBlock(64, 64)
-        self.conv_1r = ConvBlock(128, 64)
-        self.up_0r = UpBlock(64, 64)
-        self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
+        self.up_1r = ConvBlock(256, 128, seq='UBA')
+        self.conv_1r = nn.Sequential(
+            ResBlock(256, seq='CBACB'),
+            ResBlock(256, seq='CBACB'),
+        )
+        self.up_0r = ConvBlock(256, 64, seq='UBA')
+        self.conv_0r = nn.Sequential(
+            ResBlock(128, seq='CBACBAC'),
+            ConvBlock(128, out_channels, seq='C')
+        )
 
     def forward(self, x):
         y0 = self.conv_0l(x)
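Taken together, the channel plan is now 64 → 128 → 256 at the bottleneck and back up, with stacks of ResBlocks at every level. A hypothetical smoke test; the import path, the unpadded-convolution assumption, and the input size are all mine, not the commit's:

import torch
from models.unet import UNet  # import path assumed; adjust to the repo layout

net = UNet(in_channels=1, out_channels=1)
x = torch.randn(1, 1, 128, 128, 128)  # a single-channel 3D field; must be large
                                      # enough to survive the unpadded convs
y = net(x)
print(y.shape)  # spatially smaller than the input if the convolutions are unpadded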