Make both fat and lean V-Net available
This commit is contained in:
parent
6e06682751
commit
1f89e894cc
@ -1,3 +1,3 @@
|
|||||||
from .unet import UNet
|
from .unet import UNet
|
||||||
from .vnet import VNet
|
from .vnet import VNet, VNetFat
|
||||||
from .conv import narrow_like
|
from .conv import narrow_like
|
||||||
|
@ -42,3 +42,58 @@ class VNet(nn.Module):
|
|||||||
x = self.conv_r0(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VNetFat(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv_l0 = nn.Sequential(
|
||||||
|
ResBlock(in_channels, 64, seq='CACBA'),
|
||||||
|
ResBlock(64, seq='CBACBA'),
|
||||||
|
)
|
||||||
|
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||||
|
self.conv_l1 = nn.Sequential(
|
||||||
|
ResBlock(64, seq='CBACBA'),
|
||||||
|
ResBlock(64, seq='CBACBA'),
|
||||||
|
) # FIXME: test CBACBA+DBA vs CBAC+BADBA
|
||||||
|
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||||
|
|
||||||
|
self.conv_c = nn.Sequential(
|
||||||
|
ResBlock(64, seq='CBACBA'),
|
||||||
|
ResBlock(64, seq='CBACBA'),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||||
|
self.conv_r1 = nn.Sequential(
|
||||||
|
ResBlock(128, seq='CBACBA'),
|
||||||
|
ResBlock(128, seq='CBACBA'),
|
||||||
|
)
|
||||||
|
self.up_r0 = ConvBlock(128, 64, seq='UBA')
|
||||||
|
self.conv_r0 = nn.Sequential(
|
||||||
|
ResBlock(128, seq='CBACBA'),
|
||||||
|
ResBlock(128, out_channels, seq='CAC'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y0 = self.conv_l0(x)
|
||||||
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
|
y1 = self.conv_l1(x)
|
||||||
|
x = self.down_l1(y1)
|
||||||
|
|
||||||
|
x = self.conv_c(x)
|
||||||
|
|
||||||
|
x = self.up_r1(x)
|
||||||
|
y1 = narrow_like(y1, x)
|
||||||
|
x = torch.cat([y1, x], dim=1)
|
||||||
|
del y1
|
||||||
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
|
x = self.up_r0(x)
|
||||||
|
y0 = narrow_like(y0, x)
|
||||||
|
x = torch.cat([y0, x], dim=1)
|
||||||
|
del y0
|
||||||
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user