From 1f89e894cc13b28f15c7713c577a651f7ae75279 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Tue, 21 Jan 2020 15:22:42 -0500 Subject: [PATCH] Make both fat and lean V-Net available --- map2map/models/__init__.py | 2 +- map2map/models/vnet.py | 55 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/map2map/models/__init__.py b/map2map/models/__init__.py index 6dc7b5c..86f733e 100644 --- a/map2map/models/__init__.py +++ b/map2map/models/__init__.py @@ -1,3 +1,3 @@ from .unet import UNet -from .vnet import VNet +from .vnet import VNet, VNetFat from .conv import narrow_like diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index e8bb8d5..0c089b8 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -42,3 +42,58 @@ class VNet(nn.Module): x = self.conv_r0(x) return x + + +class VNetFat(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.conv_l0 = nn.Sequential( + ResBlock(in_channels, 64, seq='CACBA'), + ResBlock(64, seq='CBACBA'), + ) + self.down_l0 = ConvBlock(64, seq='DBA') + self.conv_l1 = nn.Sequential( + ResBlock(64, seq='CBACBA'), + ResBlock(64, seq='CBACBA'), + ) # FIXME: test CBACBA+DBA vs CBAC+BADBA + self.down_l1 = ConvBlock(64, seq='DBA') + + self.conv_c = nn.Sequential( + ResBlock(64, seq='CBACBA'), + ResBlock(64, seq='CBACBA'), + ) + + self.up_r1 = ConvBlock(64, seq='UBA') + self.conv_r1 = nn.Sequential( + ResBlock(128, seq='CBACBA'), + ResBlock(128, seq='CBACBA'), + ) + self.up_r0 = ConvBlock(128, 64, seq='UBA') + self.conv_r0 = nn.Sequential( + ResBlock(128, seq='CBACBA'), + ResBlock(128, out_channels, seq='CAC'), + ) + + def forward(self, x): + y0 = self.conv_l0(x) + x = self.down_l0(y0) + + y1 = self.conv_l1(x) + x = self.down_l1(y1) + + x = self.conv_c(x) + + x = self.up_r1(x) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x) + + x = self.up_r0(x) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r0(x) + + return x