import math

import numpy as np
import torch
import torch.nn as nn

from map2map.models.narrow import narrow_by, narrow_like
from map2map.models.styled_conv import ConvStyledBlock, ResStyledBlock


def _narrow_cat(skip, x):
    """Crop ``skip`` to the spatial size of ``x`` and concatenate them.

    Used for the V-Net skip connections: the encoder feature map ``skip`` is
    cropped with ``narrow_like`` to match the decoder map ``x`` and stacked
    in front of it along the channel dimension (dim 1).

    Args:
        skip (torch.Tensor): Encoder feature map, spatially >= ``x``.
        x (torch.Tensor): Decoder feature map.

    Returns:
        torch.Tensor: ``cat([narrow_like(skip, x), x], dim=1)``.
    """
    return torch.cat([narrow_like(skip, x), x], dim=1)


class SimpleStyledVNet(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
        """Simplified two-level V-Net like network with styles.

        Same layout as `StyledVNet`, but the "down"/"up" blocks use
        ``seq='BA'`` (no 'D'/'U' resampling layer) and most residual blocks
        use a single convolution (``seq='C'``).

        See `vnet.VNet`.

        Args:
            style_size (int): Size of the style vector ``s``.
            in_chan (int): Number of input channels.
            out_chan (int): Number of output channels.
            bypass (bool, optional): If true, add the (cropped) input to the
                output. When None, enabled iff ``in_chan == out_chan``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
        self.down_l0 = ConvStyledBlock(style_size, 64, seq='BA')
        self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='C')
        self.down_l1 = ConvStyledBlock(style_size, 64, seq='BA')

        self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')

        self.up_r1 = ConvStyledBlock(style_size, 64, seq='BA')
        # 128 = 64 skip channels + 64 decoder channels after concatenation.
        self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='C')
        self.up_r0 = ConvStyledBlock(style_size, 64, seq='BA')
        self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='C')

        self.bypass = in_chan == out_chan if bypass is None else bypass

    def forward(self, x, s):
        """Run the network.

        Args:
            x (torch.Tensor): Input with channels on dim 1.
            s (torch.Tensor): Style tensor passed to every block.

        Returns:
            torch.Tensor: Output with ``out_chan`` channels; its spatial size
            may be smaller than the input's.
        """
        if self.bypass:
            x0 = x

        y0 = self.conv_l0(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        x = self.conv_c(x, s)

        x = self.up_r1(x, s)
        # narrow_like crops every spatial dim to x's size directly, instead
        # of a single margin computed from dim 2 only.
        x = _narrow_cat(y1, x)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        x = _narrow_cat(y0, x)
        del y0
        x = self.conv_r0(x, s)

        if self.bypass:
            # Crop to the actual output size rather than a fixed margin
            # (previously 20), which only fits one particular layer stack.
            x += narrow_like(x0, x)

        return x


class StyledVNet(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
        """Two-level V-Net like network with styles.

        See `vnet.VNet`.

        Args:
            style_size (int): Size of the style vector ``s``.
            in_chan (int): Number of input channels.
            out_chan (int): Number of output channels.
            bypass (bool, optional): If true, add the (cropped) input to the
                output. When None, enabled iff ``in_chan == out_chan``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # activate non-identity skip connection in residual block
        # by explicitly setting out_chan
        self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
        self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
        self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
        self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')

        self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')

        self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
        # 128 = 64 skip channels + 64 decoder channels after concatenation.
        self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
        self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
        self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')

        self.bypass = in_chan == out_chan if bypass is None else bypass

    def forward(self, x, s):
        """Run the network.

        Args:
            x (torch.Tensor): Input with channels on dim 1.
            s (torch.Tensor): Style tensor passed to every block.

        Returns:
            torch.Tensor: Output with ``out_chan`` channels; its spatial size
            may be smaller than the input's.
        """
        if self.bypass:
            x0 = x

        y0 = self.conv_l0(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        x = self.conv_c(x, s)

        x = self.up_r1(x, s)
        x = _narrow_cat(y1, x)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        x = _narrow_cat(y0, x)
        del y0
        x = self.conv_r0(x, s)

        if self.bypass:
            # Crop to the actual output size rather than a fixed margin
            # (previously 20), which only fits one particular layer stack;
            # when 20 was correct this yields the same result.
            x += narrow_like(x0, x)

        return x


class StyledVNetD2(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, bypass=None, ch_base=64,
                 **kwargs):
        """Two-level V-Net like network with styles and a constant width.

        See `vnet.VNet`.

        Args:
            style_size (int): Size of the style vector ``s``.
            in_chan (int): Number of input channels.
            out_chan (int): Number of output channels.
            bypass (bool, optional): If true, add the (cropped) input to the
                output. When None, enabled iff ``in_chan == out_chan``.
            ch_base (int): Channel number used by every internal layer.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Input padding expected by callers; presumably the per-side border
        # trimmed by the network -- confirm against the training code.
        self.in_pad = 24

        self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CBACBA')
        self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
        self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')

        self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')

        self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
        # ch_base * 2 = skip channels + decoder channels after concatenation.
        self.conv_r1 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
        self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')

        self.bypass = in_chan == out_chan if bypass is None else bypass

    def forward(self, x, s):
        """Run the network.

        Args:
            x (torch.Tensor): Input with channels on dim 1.
            s (torch.Tensor): Style tensor passed to every block.

        Returns:
            torch.Tensor: Output with ``out_chan`` channels; its spatial size
            may be smaller than the input's.
        """
        if self.bypass:
            x0 = x

        x = self.conv_l00(x, s)
        y0 = self.conv_l01(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        x = self.conv_c(x, s)

        x = self.up_r1(x, s)
        x = _narrow_cat(y1, x)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        x = _narrow_cat(y0, x)
        del y0
        x = self.conv_r00(x, s)
        x = self.conv_r01(x, s)

        if self.bypass:
            x += narrow_like(x0, x)

        return x


class StyledVNetD3(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, ch_base=64, bypass=None,
                 **kwargs):
        """Three-level V-Net like network with styles and a constant width.

        See `vnet.VNet`.

        Args:
            style_size (int): Size of the style vector ``s``.
            in_chan (int): Number of input channels.
            out_chan (int): Number of output channels.
            ch_base (int): Channel number used by every internal layer.
            bypass (bool, optional): If true, add the (cropped) input to the
                output. When None, enabled iff ``in_chan == out_chan``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Input padding expected by callers; presumably the per-side border
        # trimmed by the network -- confirm against the training code.
        self.in_pad = 48

        # activate non-identity skip connection in residual block
        # by explicitly setting out_chan.
        # conv_l00 must emit ch_base channels (it used to hard-code 64),
        # since conv_l01 consumes ch_base.
        self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CBACBA')
        self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
        self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')
        self.conv_l2 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l2 = ConvStyledBlock(style_size, ch_base, seq='DBA')

        self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')

        self.up_r2 = ConvStyledBlock(style_size, ch_base, seq='UBA')
        self.conv_r2 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
        self.conv_r1 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
        self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')

        self.bypass = in_chan == out_chan if bypass is None else bypass

    def forward(self, x, s):
        """Run the network.

        Args:
            x (torch.Tensor): Input with channels on dim 1.
            s (torch.Tensor): Style tensor passed to every block.

        Returns:
            torch.Tensor: Output with ``out_chan`` channels; its spatial size
            may be smaller than the input's.
        """
        if self.bypass:
            x0 = x

        x = self.conv_l00(x, s)
        y0 = self.conv_l01(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        y2 = self.conv_l2(x, s)
        x = self.down_l2(y2, s)

        x = self.conv_c(x, s)

        x = self.up_r2(x, s)
        x = _narrow_cat(y2, x)
        del y2
        x = self.conv_r2(x, s)

        x = self.up_r1(x, s)
        x = _narrow_cat(y1, x)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        x = _narrow_cat(y0, x)
        del y0
        x = self.conv_r00(x, s)
        x = self.conv_r01(x, s)

        if self.bypass:
            x += narrow_like(x0, x)

        return x


class StyledVNetD4(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, ch_base=32, bypass=None,
                 **kwargs):
        """Three-level V-Net like network with styles and doubling widths.

        The encoder doubles the channel count at every downsampling step
        (ch_base -> 2x -> 4x -> 8x); the decoder halves it again at every
        upsampling step, mirroring the encoder.

        See `vnet.VNet`.

        Args:
            style_size (int): Size of the style vector ``s``.
            in_chan (int): Number of input channels.
            out_chan (int): Number of output channels.
            ch_base (int): Channel number of the top (full resolution) level.
            bypass (bool, optional): If true, add the (cropped) input to the
                output. When None, enabled iff ``in_chan == out_chan``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        # Input padding expected by callers; presumably the per-side border
        # trimmed by the network -- confirm against the training code.
        self.in_pad = 48

        # activate non-identity skip connection in residual block
        # by explicitly setting out_chan
        self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CACA')
        self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
        self.down_l0 = ConvStyledBlock(style_size, ch_base, ch_base * 2, seq='DA')
        self.conv_l1 = ResStyledBlock(style_size, ch_base * 2, ch_base * 2, seq='CBACBA')
        self.down_l1 = ConvStyledBlock(style_size, ch_base * 2, ch_base * 4, seq='DA')
        self.conv_l2 = ResStyledBlock(style_size, ch_base * 4, ch_base * 4, seq='CBACBA')
        self.down_l2 = ConvStyledBlock(style_size, ch_base * 4, ch_base * 8, seq='DA')

        self.conv_c = ResStyledBlock(style_size, ch_base * 8, ch_base * 8, seq='CACA')

        self.up_r2 = ConvStyledBlock(style_size, ch_base * 8, ch_base * 4, seq='UA')
        # concat(y2: 4x, up: 4x) -> 8x channels
        self.conv_r2 = ResStyledBlock(style_size, ch_base * 8, ch_base * 4, seq='CBACBA')
        self.up_r1 = ConvStyledBlock(style_size, ch_base * 4, ch_base * 2, seq='UA')
        # concat(y1: 2x, up: 2x) -> 4x channels
        self.conv_r1 = ResStyledBlock(style_size, ch_base * 4, ch_base * 2, seq='CBACBA')
        # conv_r1 emits ch_base*2 channels, so up_r0 must consume ch_base*2
        # (it was declared with ch_base*4, which made forward fail). It
        # emits ch_base so that concat(y0: 1x, up: 1x) -> 2x channels,
        # mirroring the channel halving of the levels above.
        self.up_r0 = ConvStyledBlock(style_size, ch_base * 2, ch_base, seq='UA')
        self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
        self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')

        self.bypass = in_chan == out_chan if bypass is None else bypass

    def forward(self, x, s):
        """Run the network.

        Args:
            x (torch.Tensor): Input with channels on dim 1.
            s (torch.Tensor): Style tensor passed to every block.

        Returns:
            torch.Tensor: Output with ``out_chan`` channels; its spatial size
            may be smaller than the input's.
        """
        if self.bypass:
            x0 = x

        x = self.conv_l00(x, s)
        y0 = self.conv_l01(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        y2 = self.conv_l2(x, s)
        x = self.down_l2(y2, s)

        x = self.conv_c(x, s)

        x = self.up_r2(x, s)
        x = _narrow_cat(y2, x)
        del y2
        x = self.conv_r2(x, s)

        x = self.up_r1(x, s)
        x = _narrow_cat(y1, x)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        x = _narrow_cat(y0, x)
        del y0
        x = self.conv_r00(x, s)
        x = self.conv_r01(x, s)

        if self.bypass:
            x += narrow_like(x0, x)

        return x