From 36f38ef256a8cf90dd26bddea717691d817d2944 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Wed, 25 Jun 2025 17:00:23 +0200 Subject: [PATCH] copied COCA preexisting models --- sCOCA_ML/models/StyledVNet_models.py | 343 +++++++++++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 sCOCA_ML/models/StyledVNet_models.py diff --git a/sCOCA_ML/models/StyledVNet_models.py b/sCOCA_ML/models/StyledVNet_models.py new file mode 100644 index 0000000..bd862d2 --- /dev/null +++ b/sCOCA_ML/models/StyledVNet_models.py @@ -0,0 +1,343 @@ +import torch.nn as nn +import torch +from map2map.models.styled_conv import ConvStyledBlock, ResStyledBlock +from map2map.models.narrow import narrow_by, narrow_like +import math +import numpy as np + +class SimpleStyledVNet(nn.Module): + def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs): + """V-Net like network with styles + + See `vnet.VNet`. + """ + super().__init__() + + self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C') + self.down_l0 = ConvStyledBlock(style_size, 64, seq='BA') + self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='C') + self.down_l1 = ConvStyledBlock(style_size, 64, seq='BA') + + self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA') + + self.up_r1 = ConvStyledBlock(style_size, 64, seq='BA') + self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='C') + self.up_r0 = ConvStyledBlock(style_size, 64, seq='BA') + self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='C') + + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass + + def forward(self, x, s): + if self.bypass: + x0 = x + + y0 = self.conv_l0(x, s) + x = self.down_l0(y0, s) + + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) + + x = self.conv_c(x, s) + + x = self.up_r1(x, s) + y1 = narrow_by(y1, int((y1.shape[2] - x.shape[2])/2)) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x, s) + + x = self.up_r0(x, s) + y0 = narrow_by(y0, int((y0.shape[2] - x.shape[2])/2)) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r0(x, s) + + if self.bypass: + x0 = narrow_by(x0, 20) + x += x0 + + return x + + + +class StyledVNet(nn.Module): + def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs): + """V-Net like network with styles + + See `vnet.VNet`. + """ + super().__init__() + + # activate non-identity skip connection in residual block + # by explicitly setting out_chan + self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C') + self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA') + self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA') + self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA') + + self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA') + + self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA') + self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA') + self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') + self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') + + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass + + def forward(self, x, s): + if self.bypass: + x0 = x + + y0 = self.conv_l0(x, s) + x = self.down_l0(y0, s) + + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) + + x = self.conv_c(x, s) + + x = self.up_r1(x, s) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x, s) + + x = self.up_r0(x, s) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r0(x, s) + + if self.bypass: + x0 = narrow_by(x0, 20) + x += x0 + + return x + + +class StyledVNetD2(nn.Module): + def __init__(self, style_size, in_chan, out_chan, bypass=None, ch_base=64, **kwargs): + """V-Net like network with styles. + See `vnet.VNet`. + + Args: + style_size (int): Size of the style vector. + in_chan (int): Number of input channels. + out_chan (int): Number of output channels. + bypass (bool, optional): Enable or disable bypass connection if in_chan equals out_chan. + ch_base (int): Base channel number for the initial layers. + ch_l1 (int): Channel number for mid-level layers. + ch_out (int): Output channels before the final layer. + """ + super(StyledVNetD2, self).__init__() + + # Define the padding + self.in_pad = 24 + + # Adjusted channel numbers in layer definitions + self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CBACBA') + self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA') + self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA') + + self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + + self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA') + self.conv_r1 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA') + self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA') + self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA') + self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC') + + self.bypass = bypass if bypass is not None else in_chan == out_chan + + def forward(self, x, s): + if self.bypass: + x0 = x + + x = self.conv_l00(x, s) + y0 = self.conv_l01(x, s) + x = self.down_l0(y0, s) + + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) + + x = self.conv_c(x, s) + + x = self.up_r1(x, s) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x, s) + + x = self.up_r0(x, s) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r00(x, s) + x = self.conv_r01(x, s) + + if self.bypass: + x0 = narrow_like(x0, x) + x += x0 + + return x + + +class StyledVNetD3(nn.Module): + def __init__(self, style_size, in_chan, out_chan, ch_base=64, bypass=None, **kwargs): + """V-Net like network with styles + See `vnet.VNet`. + """ + super().__init__() + + # Define the padding + self.in_pad = 48 + + # activate non-identity skip connection in residual block + # by explicitly setting out_chan + self.conv_l00 = ResStyledBlock(style_size, in_chan, 64, seq='CBACBA') + self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA') + self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA') + self.conv_l2 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l2 = ConvStyledBlock(style_size, ch_base, seq='DBA') + + self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + + self.up_r2 = ConvStyledBlock(style_size, ch_base, seq='UBA') + self.conv_r2 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA') + self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA') + self.conv_r1 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA') + self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA') + self.conv_r00 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA') + self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC') + + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass + + def forward(self, x, s): + if self.bypass: + x0 = x + + x = self.conv_l00(x, s) + y0 = self.conv_l01(x, s) + x = self.down_l0(y0, s) + + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) + + y2 = self.conv_l2(x, s) + x = self.down_l2(y2, s) + + x = self.conv_c(x, s) + + x = self.up_r2(x, s) + y2 = narrow_like(y2, x) + x = torch.cat([y2, x], dim=1) + del y2 + x = self.conv_r2(x, s) + + x = self.up_r1(x, s) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x, s) + + x = self.up_r0(x, s) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r00(x, s) + x = self.conv_r01(x, s) + + if self.bypass: + x0 = narrow_like(x0, x) + x += x0 + + return x + + +class StyledVNetD4(nn.Module): + def __init__(self, style_size, in_chan, out_chan, ch_base=32, bypass=None, **kwargs): + """V-Net like network with styles + See `vnet.VNet`. + """ + super().__init__() + + # Define the padding + self.in_pad = 48 + + # activate non-identity skip connection in residual block + # by explicitly setting out_chan + self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CACA') + self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA') + self.down_l0 = ConvStyledBlock(style_size, ch_base, ch_base*2, seq='DA') + self.conv_l1 = ResStyledBlock(style_size, ch_base*2, ch_base*2, seq='CBACBA') + self.down_l1 = ConvStyledBlock(style_size, ch_base*2, ch_base*4, seq='DA') + self.conv_l2 = ResStyledBlock(style_size, ch_base*4, ch_base*4, seq='CBACBA') + self.down_l2 = ConvStyledBlock(style_size, ch_base*4, ch_base*8, seq='DA') + + self.conv_c = ResStyledBlock(style_size, ch_base*8, ch_base*8, seq='CACA') + + self.up_r2 = ConvStyledBlock(style_size, ch_base*8, ch_base*4, seq='UA') + self.conv_r2 = ResStyledBlock(style_size, ch_base*8, ch_base*4, seq='CBACBA') + self.up_r1 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA') + self.conv_r1 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA') + self.up_r0 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA') + self.conv_r00 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA') + self.conv_r01 = ResStyledBlock(style_size, ch_base*2, out_chan, seq='CBAC') + + if bypass is None: + self.bypass = in_chan == out_chan + else: + self.bypass = bypass + + def forward(self, x, s): + if self.bypass: + x0 = x + + x = self.conv_l00(x, s) + y0 = self.conv_l01(x, s) + x = self.down_l0(y0, s) + + y1 = self.conv_l1(x, s) + x = self.down_l1(y1, s) + + y2 = self.conv_l2(x, s) + x = self.down_l2(y2, s) + + x = self.conv_c(x, s) + + x = self.up_r2(x, s) + y2 = narrow_like(y2, x) + x = torch.cat([y2, x], dim=1) + del y2 + x = self.conv_r2(x, s) + + x = self.up_r1(x, s) + y1 = narrow_like(y1, x) + x = torch.cat([y1, x], dim=1) + del y1 + x = self.conv_r1(x, s) + + x = self.up_r0(x, s) + y0 = narrow_like(y0, x) + x = torch.cat([y0, x], dim=1) + del y0 + x = self.conv_r00(x, s) + x = self.conv_r01(x, s) + + if self.bypass: + x0 = narrow_like(x0, x) + x += x0 + + return x