From f5bd6576258d4e678ddb3e90280eff3e3514048e Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 26 Feb 2021 16:23:50 -0500 Subject: [PATCH] Duplicate vnet.py for future modification To incorporate styled convolution from style.py --- map2map/models/styled_vnet.py | 60 +++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 map2map/models/styled_vnet.py diff --git a/map2map/models/styled_vnet.py b/map2map/models/styled_vnet.py new file mode 100644 index 0000000..9ef8475 --- /dev/null +++ b/map2map/models/styled_vnet.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + +from .conv import ConvBlock, ResBlock +from .narrow import narrow_by + + +class StyledVNet(nn.Module): + def __init__(self, in_chan, out_chan, bypass=None, **kwargs): + """V-Net like network with styles + + See `vnet.VNet`. + """ + super().__init__() + + # activate non-identity skip connection in residual block + # by explicitly setting out_chan + self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA') + self.down_l0 = ConvBlock(64, seq='DBA') + self.conv_l1 = ResBlock(64, 64, seq='CBACBA') + self.down_l1 = ConvBlock(64, seq='DBA') + + self.conv_c = ResBlock(64, 64, seq='CBACBA') + + self.up_r1 = ConvBlock(64, seq='UBA') + self.conv_r1 = ResBlock(128, 64, seq='CBACBA') + self.up_r0 = ConvBlock(64, seq='UBA') + self.conv_r0 = ResBlock(128, out_chan, seq='CAC') + + self.bypass = in_chan == out_chan + + def forward(self, x): + if self.bypass: + x0 = x + + y0 = self.conv_l0(x) + x = self.down_l0(y0) + + y1 = self.conv_l1(x) + x = self.down_l1(y1) + + x = self.conv_c(x) + + x = self.up_r1(x) + y1 = narrow_by(y1, 4) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x) + + x = self.up_r0(x) + y0 = narrow_by(y0, 16) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r0(x) + + if self.bypass: + x0 = narrow_by(x0, 20) + x += x0 + + return x