copied COCA preexisting models

This commit is contained in:
Mayeul Aubin 2025-06-25 17:00:23 +02:00
parent 118455567b
commit 36f38ef256

View file

@ -0,0 +1,343 @@
import torch.nn as nn
import torch
from map2map.models.styled_conv import ConvStyledBlock, ResStyledBlock
from map2map.models.narrow import narrow_by, narrow_like
import math
import numpy as np
class SimpleStyledVNet(nn.Module):
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
"""V-Net like network with styles
See `vnet.VNet`.
"""
super().__init__()
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
self.down_l0 = ConvStyledBlock(style_size, 64, seq='BA')
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='C')
self.down_l1 = ConvStyledBlock(style_size, 64, seq='BA')
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
self.up_r1 = ConvStyledBlock(style_size, 64, seq='BA')
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='C')
self.up_r0 = ConvStyledBlock(style_size, 64, seq='BA')
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='C')
if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x, s):
if self.bypass:
x0 = x
y0 = self.conv_l0(x, s)
x = self.down_l0(y0, s)
y1 = self.conv_l1(x, s)
x = self.down_l1(y1, s)
x = self.conv_c(x, s)
x = self.up_r1(x, s)
y1 = narrow_by(y1, int((y1.shape[2] - x.shape[2])/2))
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x, s)
x = self.up_r0(x, s)
y0 = narrow_by(y0, int((y0.shape[2] - x.shape[2])/2))
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r0(x, s)
if self.bypass:
x0 = narrow_by(x0, 20)
x += x0
return x
class StyledVNet(nn.Module):
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
"""V-Net like network with styles
See `vnet.VNet`.
"""
super().__init__()
# activate non-identity skip connection in residual block
# by explicitly setting out_chan
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x, s):
if self.bypass:
x0 = x
y0 = self.conv_l0(x, s)
x = self.down_l0(y0, s)
y1 = self.conv_l1(x, s)
x = self.down_l1(y1, s)
x = self.conv_c(x, s)
x = self.up_r1(x, s)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x, s)
x = self.up_r0(x, s)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r0(x, s)
if self.bypass:
x0 = narrow_by(x0, 20)
x += x0
return x
class StyledVNetD2(nn.Module):
def __init__(self, style_size, in_chan, out_chan, bypass=None, ch_base=64, **kwargs):
"""V-Net like network with styles.
See `vnet.VNet`.
Args:
style_size (int): Size of the style vector.
in_chan (int): Number of input channels.
out_chan (int): Number of output channels.
bypass (bool, optional): Enable or disable bypass connection if in_chan equals out_chan.
ch_base (int): Base channel number for the initial layers.
ch_l1 (int): Channel number for mid-level layers.
ch_out (int): Output channels before the final layer.
"""
super(StyledVNetD2, self).__init__()
# Define the padding
self.in_pad = 24
# Adjusted channel numbers in layer definitions
self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CBACBA')
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')
self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
self.conv_r1 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')
self.bypass = bypass if bypass is not None else in_chan == out_chan
def forward(self, x, s):
if self.bypass:
x0 = x
x = self.conv_l00(x, s)
y0 = self.conv_l01(x, s)
x = self.down_l0(y0, s)
y1 = self.conv_l1(x, s)
x = self.down_l1(y1, s)
x = self.conv_c(x, s)
x = self.up_r1(x, s)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x, s)
x = self.up_r0(x, s)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r00(x, s)
x = self.conv_r01(x, s)
if self.bypass:
x0 = narrow_like(x0, x)
x += x0
return x
class StyledVNetD3(nn.Module):
def __init__(self, style_size, in_chan, out_chan, ch_base=64, bypass=None, **kwargs):
"""V-Net like network with styles
See `vnet.VNet`.
"""
super().__init__()
# Define the padding
self.in_pad = 48
# activate non-identity skip connection in residual block
# by explicitly setting out_chan
self.conv_l00 = ResStyledBlock(style_size, in_chan, 64, seq='CBACBA')
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')
self.conv_l2 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l2 = ConvStyledBlock(style_size, ch_base, seq='DBA')
self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.up_r2 = ConvStyledBlock(style_size, ch_base, seq='UBA')
self.conv_r2 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
self.conv_r1 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
self.conv_r00 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')
if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x, s):
if self.bypass:
x0 = x
x = self.conv_l00(x, s)
y0 = self.conv_l01(x, s)
x = self.down_l0(y0, s)
y1 = self.conv_l1(x, s)
x = self.down_l1(y1, s)
y2 = self.conv_l2(x, s)
x = self.down_l2(y2, s)
x = self.conv_c(x, s)
x = self.up_r2(x, s)
y2 = narrow_like(y2, x)
x = torch.cat([y2, x], dim=1)
del y2
x = self.conv_r2(x, s)
x = self.up_r1(x, s)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x, s)
x = self.up_r0(x, s)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r00(x, s)
x = self.conv_r01(x, s)
if self.bypass:
x0 = narrow_like(x0, x)
x += x0
return x
class StyledVNetD4(nn.Module):
def __init__(self, style_size, in_chan, out_chan, ch_base=32, bypass=None, **kwargs):
"""V-Net like network with styles
See `vnet.VNet`.
"""
super().__init__()
# Define the padding
self.in_pad = 48
# activate non-identity skip connection in residual block
# by explicitly setting out_chan
self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CACA')
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
self.down_l0 = ConvStyledBlock(style_size, ch_base, ch_base*2, seq='DA')
self.conv_l1 = ResStyledBlock(style_size, ch_base*2, ch_base*2, seq='CBACBA')
self.down_l1 = ConvStyledBlock(style_size, ch_base*2, ch_base*4, seq='DA')
self.conv_l2 = ResStyledBlock(style_size, ch_base*4, ch_base*4, seq='CBACBA')
self.down_l2 = ConvStyledBlock(style_size, ch_base*4, ch_base*8, seq='DA')
self.conv_c = ResStyledBlock(style_size, ch_base*8, ch_base*8, seq='CACA')
self.up_r2 = ConvStyledBlock(style_size, ch_base*8, ch_base*4, seq='UA')
self.conv_r2 = ResStyledBlock(style_size, ch_base*8, ch_base*4, seq='CBACBA')
self.up_r1 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA')
self.conv_r1 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA')
self.up_r0 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA')
self.conv_r00 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA')
self.conv_r01 = ResStyledBlock(style_size, ch_base*2, out_chan, seq='CBAC')
if bypass is None:
self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x, s):
if self.bypass:
x0 = x
x = self.conv_l00(x, s)
y0 = self.conv_l01(x, s)
x = self.down_l0(y0, s)
y1 = self.conv_l1(x, s)
x = self.down_l1(y1, s)
y2 = self.conv_l2(x, s)
x = self.down_l2(y2, s)
x = self.conv_c(x, s)
x = self.up_r2(x, s)
y2 = narrow_like(y2, x)
x = torch.cat([y2, x], dim=1)
del y2
x = self.conv_r2(x, s)
x = self.up_r1(x, s)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x, s)
x = self.up_r0(x, s)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r00(x, s)
x = self.conv_r01(x, s)
if self.bypass:
x0 = narrow_like(x0, x)
x += x0
return x