From 94ce018cb84fe07c17a0603579f3495d1d26cc3f Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 20 Jan 2020 21:49:01 -0500 Subject: [PATCH] Change U-Net and V-Net inner naming convention; Slim down V-Net --- map2map/models/unet.py | 36 ++++++++++++++--------------- map2map/models/vnet.py | 51 +++++++++++++++--------------------------- 2 files changed, 36 insertions(+), 51 deletions(-) diff --git a/map2map/models/unet.py b/map2map/models/unet.py index 5ae9c6e..70c122a 100644 --- a/map2map/models/unet.py +++ b/map2map/models/unet.py @@ -8,37 +8,37 @@ class UNet(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - self.conv_0l = ConvBlock(in_channels, 64, seq='CAC') - self.down_0l = ConvBlock(64, seq='BADBA') - self.conv_1l = ConvBlock(64, seq='CBAC') - self.down_1l = ConvBlock(64, seq='BADBA') + self.conv_l0 = ConvBlock(in_channels, 64, seq='CAC') + self.down_l0 = ConvBlock(64, seq='BADBA') + self.conv_l1 = ConvBlock(64, seq='CBAC') + self.down_l1 = ConvBlock(64, seq='BADBA') - self.conv_2c = ConvBlock(64, seq='CBAC') + self.conv_c = ConvBlock(64, seq='CBAC') - self.up_1r = ConvBlock(64, seq='BAUBA') - self.conv_1r = ConvBlock(128, 64, seq='CBAC') - self.up_0r = ConvBlock(64, seq='BAUBA') - self.conv_0r = ConvBlock(128, out_channels, seq='CAC') + self.up_r1 = ConvBlock(64, seq='BAUBA') + self.conv_r1 = ConvBlock(128, 64, seq='CBAC') + self.up_r0 = ConvBlock(64, seq='BAUBA') + self.conv_r0 = ConvBlock(128, out_channels, seq='CAC') def forward(self, x): - y0 = self.conv_0l(x) - x = self.down_0l(y0) + y0 = self.conv_l0(x) + x = self.down_l0(y0) - y1 = self.conv_1l(x) - x = self.down_1l(y1) + y1 = self.conv_l1(x) + x = self.down_l1(y1) - x = self.conv_2c(x) + x = self.conv_c(x) - x = self.up_1r(x) + x = self.up_r1(x) y1 = narrow_like(y1, x) x = torch.cat([y1, x], dim=1) del y1 - x = self.conv_1r(x) + x = self.conv_r1(x) - x = self.up_0r(x) + x = self.up_r0(x) y0 = narrow_like(y0, x) x = torch.cat([y0, x], dim=1) del y0 - x = self.conv_0r(x) + x = self.conv_r0(x) return x diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index 84ffa3e..e8bb8d5 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -8,52 +8,37 @@ class VNet(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - self.conv_0l = nn.Sequential( - ConvBlock(in_channels, 64, seq='CA'), - ResBlock(64, seq='CBACBACBA'), - ) - self.down_0l = ConvBlock(64, 128, seq='DBA') - self.conv_1l = nn.Sequential( - ResBlock(128, seq='CBACBA'), - ResBlock(128, seq='CBACBA'), - ) - self.down_1l = ConvBlock(128, 256, seq='DBA') + self.conv_l0 = ResBlock(in_channels, 64, seq='CAC') + self.down_l0 = ConvBlock(64, seq='BADBA') + self.conv_l1 = ResBlock(64, seq='CBAC') + self.down_l1 = ConvBlock(64, seq='BADBA') - self.conv_2c = nn.Sequential( - ResBlock(256, seq='CBACBA'), - ResBlock(256, seq='CBACBA'), - ) + self.conv_c = ResBlock(64, seq='CBAC') - self.up_1r = ConvBlock(256, 128, seq='UBA') - self.conv_1r = nn.Sequential( - ResBlock(256, seq='CBACBA'), - ResBlock(256, seq='CBACBA'), - ) - self.up_0r = ConvBlock(256, 64, seq='UBA') - self.conv_0r = nn.Sequential( - ResBlock(128, seq='CBACBACA'), - ConvBlock(128, out_channels, seq='C') - ) + self.up_r1 = ConvBlock(64, seq='BAUBA') + self.conv_r1 = ResBlock(128, 64, seq='CBAC') + self.up_r0 = ConvBlock(64, seq='BAUBA') + self.conv_r0 = ResBlock(128, out_channels, seq='CAC') def forward(self, x): - y0 = self.conv_0l(x) - x = self.down_0l(y0) + y0 = self.conv_l0(x) + x = self.down_l0(y0) - y1 = self.conv_1l(x) - x = self.down_1l(y1) + y1 = self.conv_l1(x) + x = self.down_l1(y1) - x = self.conv_2c(x) + x = self.conv_c(x) - x = self.up_1r(x) + x = self.up_r1(x) y1 = narrow_like(y1, x) x = torch.cat([y1, x], dim=1) del y1 - x = self.conv_1r(x) + x = self.conv_r1(x) - x = self.up_0r(x) + x = self.up_r0(x) y0 = narrow_like(y0, x) x = torch.cat([y0, x], dim=1) del y0 - x = self.conv_0r(x) + x = self.conv_r0(x) return x