copied COCA preexisting models
This commit is contained in:
parent
118455567b
commit
36f38ef256
1 changed files with 343 additions and 0 deletions
343
sCOCA_ML/models/StyledVNet_models.py
Normal file
343
sCOCA_ML/models/StyledVNet_models.py
Normal file
|
@ -0,0 +1,343 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from map2map.models.styled_conv import ConvStyledBlock, ResStyledBlock
|
||||
from map2map.models.narrow import narrow_by, narrow_like
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
class SimpleStyledVNet(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
|
||||
"""V-Net like network with styles
|
||||
|
||||
See `vnet.VNet`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
|
||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq='BA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='C')
|
||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq='BA')
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
|
||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq='BA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='C')
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='BA')
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='C')
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
y0 = self.conv_l0(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_by(y1, int((y1.shape[2] - x.shape[2])/2))
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_by(y0, int((y0.shape[2] - x.shape[2])/2))
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r0(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_by(x0, 20)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class StyledVNet(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
|
||||
"""V-Net like network with styles
|
||||
|
||||
See `vnet.VNet`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l0 = ResStyledBlock(style_size, in_chan, 64, seq='C')
|
||||
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DBA')
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CBACBA')
|
||||
|
||||
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
y0 = self.conv_l0(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_like(y1, x)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_like(y0, x)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r0(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_by(x0, 20)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StyledVNetD2(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, bypass=None, ch_base=64, **kwargs):
|
||||
"""V-Net like network with styles.
|
||||
See `vnet.VNet`.
|
||||
|
||||
Args:
|
||||
style_size (int): Size of the style vector.
|
||||
in_chan (int): Number of input channels.
|
||||
out_chan (int): Number of output channels.
|
||||
bypass (bool, optional): Enable or disable bypass connection if in_chan equals out_chan.
|
||||
ch_base (int): Base channel number for the initial layers.
|
||||
ch_l1 (int): Channel number for mid-level layers.
|
||||
ch_out (int): Output channels before the final layer.
|
||||
"""
|
||||
super(StyledVNetD2, self).__init__()
|
||||
|
||||
# Define the padding
|
||||
self.in_pad = 24
|
||||
|
||||
# Adjusted channel numbers in layer definitions
|
||||
self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CBACBA')
|
||||
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
|
||||
self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
|
||||
self.conv_r00 = ResStyledBlock(style_size, ch_base * 2, ch_base, seq='CBACBA')
|
||||
self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')
|
||||
|
||||
self.bypass = bypass if bypass is not None else in_chan == out_chan
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
x = self.conv_l00(x, s)
|
||||
y0 = self.conv_l01(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_like(y1, x)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_like(y0, x)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r00(x, s)
|
||||
x = self.conv_r01(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_like(x0, x)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StyledVNetD3(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, ch_base=64, bypass=None, **kwargs):
|
||||
"""V-Net like network with styles
|
||||
See `vnet.VNet`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Define the padding
|
||||
self.in_pad = 48
|
||||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l00 = ResStyledBlock(style_size, in_chan, 64, seq='CBACBA')
|
||||
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l0 = ConvStyledBlock(style_size, ch_base, seq='DBA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, ch_base, seq='DBA')
|
||||
self.conv_l2 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l2 = ConvStyledBlock(style_size, ch_base, seq='DBA')
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
|
||||
self.up_r2 = ConvStyledBlock(style_size, ch_base, seq='UBA')
|
||||
self.conv_r2 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
|
||||
self.up_r1 = ConvStyledBlock(style_size, ch_base, seq='UBA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, ch_base, seq='UBA')
|
||||
self.conv_r00 = ResStyledBlock(style_size, ch_base*2, ch_base, seq='CBACBA')
|
||||
self.conv_r01 = ResStyledBlock(style_size, ch_base, out_chan, seq='CBAC')
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
x = self.conv_l00(x, s)
|
||||
y0 = self.conv_l01(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
y2 = self.conv_l2(x, s)
|
||||
x = self.down_l2(y2, s)
|
||||
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r2(x, s)
|
||||
y2 = narrow_like(y2, x)
|
||||
x = torch.cat([y2, x], dim=1)
|
||||
del y2
|
||||
x = self.conv_r2(x, s)
|
||||
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_like(y1, x)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_like(y0, x)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r00(x, s)
|
||||
x = self.conv_r01(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_like(x0, x)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class StyledVNetD4(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, ch_base=32, bypass=None, **kwargs):
|
||||
"""V-Net like network with styles
|
||||
See `vnet.VNet`.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Define the padding
|
||||
self.in_pad = 48
|
||||
|
||||
# activate non-identity skip connection in residual block
|
||||
# by explicitly setting out_chan
|
||||
self.conv_l00 = ResStyledBlock(style_size, in_chan, ch_base, seq='CACA')
|
||||
self.conv_l01 = ResStyledBlock(style_size, ch_base, ch_base, seq='CBACBA')
|
||||
self.down_l0 = ConvStyledBlock(style_size, ch_base, ch_base*2, seq='DA')
|
||||
self.conv_l1 = ResStyledBlock(style_size, ch_base*2, ch_base*2, seq='CBACBA')
|
||||
self.down_l1 = ConvStyledBlock(style_size, ch_base*2, ch_base*4, seq='DA')
|
||||
self.conv_l2 = ResStyledBlock(style_size, ch_base*4, ch_base*4, seq='CBACBA')
|
||||
self.down_l2 = ConvStyledBlock(style_size, ch_base*4, ch_base*8, seq='DA')
|
||||
|
||||
self.conv_c = ResStyledBlock(style_size, ch_base*8, ch_base*8, seq='CACA')
|
||||
|
||||
self.up_r2 = ConvStyledBlock(style_size, ch_base*8, ch_base*4, seq='UA')
|
||||
self.conv_r2 = ResStyledBlock(style_size, ch_base*8, ch_base*4, seq='CBACBA')
|
||||
self.up_r1 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA')
|
||||
self.conv_r1 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA')
|
||||
self.up_r0 = ConvStyledBlock(style_size, ch_base*4, ch_base*2, seq='UA')
|
||||
self.conv_r00 = ResStyledBlock(style_size, ch_base*4, ch_base*2, seq='CBACBA')
|
||||
self.conv_r01 = ResStyledBlock(style_size, ch_base*2, out_chan, seq='CBAC')
|
||||
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
x0 = x
|
||||
|
||||
x = self.conv_l00(x, s)
|
||||
y0 = self.conv_l01(x, s)
|
||||
x = self.down_l0(y0, s)
|
||||
|
||||
y1 = self.conv_l1(x, s)
|
||||
x = self.down_l1(y1, s)
|
||||
|
||||
y2 = self.conv_l2(x, s)
|
||||
x = self.down_l2(y2, s)
|
||||
|
||||
x = self.conv_c(x, s)
|
||||
|
||||
x = self.up_r2(x, s)
|
||||
y2 = narrow_like(y2, x)
|
||||
x = torch.cat([y2, x], dim=1)
|
||||
del y2
|
||||
x = self.conv_r2(x, s)
|
||||
|
||||
x = self.up_r1(x, s)
|
||||
y1 = narrow_like(y1, x)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x, s)
|
||||
|
||||
x = self.up_r0(x, s)
|
||||
y0 = narrow_like(y0, x)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r00(x, s)
|
||||
x = self.conv_r01(x, s)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_like(x0, x)
|
||||
x += x0
|
||||
|
||||
return x
|
Loading…
Add table
Add a link
Reference in a new issue