216 lines
8.9 KiB
Python
216 lines
8.9 KiB
Python
"""
|
|
3D U-Net model with optional FiLM layers for style conditioning.
|
|
This model implements a 3D U-Net architecture with downsampling and upsampling blocks, and skip connections.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from .base_class_models import BaseModel
|
|
from .FiLM import FiLM
|
|
|
|
class UNetBlock(nn.Module):
|
|
def __init__(self, in_channels, out_channels, style_dim=None, batch_norm=True):
|
|
super(UNetBlock, self).__init__()
|
|
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
|
|
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
|
|
|
|
if batch_norm:
|
|
self.norm1 = nn.BatchNorm3d(out_channels)
|
|
self.norm2 = nn.BatchNorm3d(out_channels)
|
|
else:
|
|
self.norm1 = nn.Identity()
|
|
self.norm2 = nn.Identity()
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
|
|
|
def forward(self, x, style=None):
|
|
x = self.relu(self.norm1(self.conv1(x)))
|
|
x = self.relu(self.norm2(self.conv2(x)))
|
|
if self.film:
|
|
x = self.film(x, style)
|
|
return x
|
|
|
|
class UNetEncLayer(nn.Module):
|
|
def __init__(self, in_channels, out_channels, style_dim=None):
|
|
super(UNetEncLayer, self).__init__()
|
|
self.block = UNetBlock(in_channels, out_channels, style_dim)
|
|
self.pool = nn.MaxPool3d(2)
|
|
|
|
def forward(self, x, style=None):
|
|
x = self.block(x, style)
|
|
return x, self.pool(x)
|
|
|
|
class UNetDecLayer(nn.Module):
|
|
def __init__(self, in_channels, out_channels, skip_connection_channels, style_dim=None):
|
|
super(UNetDecLayer, self).__init__()
|
|
self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
|
|
self.block = UNetBlock(out_channels + skip_connection_channels, out_channels, style_dim)
|
|
|
|
def forward(self, x, skip_connection, style=None):
|
|
x = self.up(x)
|
|
if skip_connection is not None:
|
|
x = torch.cat([x, skip_connection], dim=1)
|
|
return self.block(x, style)
|
|
|
|
class UNet3D(BaseModel):
|
|
def __init__(self, N: int = 128,
|
|
in_channels: int = 2,
|
|
out_channels: int = 1,
|
|
style_dim: int = 2,
|
|
depth: int = None,
|
|
device: torch.device = torch.device('cpu'),
|
|
first_layer_channel_exponent: int = 3,
|
|
):
|
|
"""
|
|
3D U-Net model with optional FiLM layers for style conditioning.
|
|
Parameters:
|
|
- N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
|
|
- in_channels: Number of input channels (default is 2).
|
|
- out_channels: Number of output channels (default is 1).
|
|
- style_dim: Dimension of the style vector (default is 2).
|
|
- depth: Depth of the U-Net (default is None, which will be computed based on N).
|
|
- device: Device to load the model onto (default is CPU).
|
|
- first_layer_channel_exponent: Exponent for the number of channels in the first layer (default is 3).
|
|
This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
|
|
The model uses convolutional layers with ReLU activations and batch normalization.
|
|
The FiLM layers are used to condition the feature maps on style parameters.
|
|
"""
|
|
|
|
super().__init__(N=N,
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
style_parameters=style_dim,
|
|
device=device)
|
|
import numpy as np
|
|
|
|
if depth is not None:
|
|
self.depth = depth
|
|
if np.floor(np.log2(N)).astype(int) - 1 < depth:
|
|
raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
|
|
else:
|
|
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
|
self.first_layer_channel_exponent = first_layer_channel_exponent
|
|
|
|
self.enc=[]
|
|
|
|
for i in range(self.depth):
|
|
in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
|
|
out_ch = 2**(self.first_layer_channel_exponent + i)
|
|
self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
|
|
|
|
self.enc = nn.ModuleList(self.enc)
|
|
|
|
self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth - 1),
|
|
2**(self.first_layer_channel_exponent + self.depth), style_dim)
|
|
|
|
self.dec=[]
|
|
|
|
for i in range(self.depth - 1, -1, -1):
|
|
in_ch = 2**(self.first_layer_channel_exponent + i + 1)
|
|
out_ch = 2**(self.first_layer_channel_exponent + i)
|
|
skip_conn_ch = out_ch
|
|
self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
|
|
|
|
self.dec = nn.ModuleList(self.dec)
|
|
|
|
|
|
# self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=3, padding=1)
|
|
self.final = nn.Sequential(
|
|
nn.Conv3d(2**(self.first_layer_channel_exponent), 2**(self.first_layer_channel_exponent), kernel_size=3, padding=1),
|
|
nn.ReLU(inplace=True),
|
|
nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
|
|
)
|
|
|
|
|
|
|
|
def forward(self, x, style):
|
|
|
|
out = x
|
|
outlist = []
|
|
|
|
for i in range(self.depth):
|
|
skip, out = self.enc[i](out, style)
|
|
outlist.append(skip)
|
|
|
|
out = self.bottleneck(out, style)
|
|
|
|
for i in range(self.depth):
|
|
out = self.dec[i](out, outlist[self.depth - 1 - i], style)
|
|
|
|
return self.final(out)
|
|
|
|
class UNet3D_Shrink(BaseModel):
|
|
|
|
def __init__(self,
|
|
N: int = 128,
|
|
in_channels: int = 2,
|
|
out_channels: int = 1,
|
|
style_dim: int = 2,
|
|
depth: int = None,
|
|
device: torch.device = torch.device('cpu'),
|
|
first_layer_channel_exponent: int = 3,
|
|
shrink_factor_exponent: int = 1,
|
|
):
|
|
"""
|
|
A 3D U-Net model with optional FiLM layers for style conditioning and a shrink factor.
|
|
It means that the output data is of size (N/shrink_factor, N/shrink_factor, N/shrink_factor), where N is the input size.
|
|
"""
|
|
super().__init__(N=N,
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
style_parameters=style_dim,
|
|
device=device)
|
|
import numpy as np
|
|
|
|
if depth is not None:
|
|
self.depth_enc = depth
|
|
if np.floor(np.log2(N)).astype(int) - 1 < depth:
|
|
raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
|
|
else:
|
|
self.depth_enc = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
|
|
|
self.depth_dec = self.depth_enc - shrink_factor_exponent # Depth of the U-Net based on input size N and shrink factor
|
|
self.first_layer_channel_exponent = first_layer_channel_exponent
|
|
self.shrink_factor_exponent = shrink_factor_exponent
|
|
|
|
self.enc = []
|
|
for i in range(self.depth_enc):
|
|
in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
|
|
out_ch = 2**(self.first_layer_channel_exponent + i)
|
|
self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
|
|
|
|
self.enc = nn.ModuleList(self.enc)
|
|
|
|
self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth_enc - 1),
|
|
2**(self.first_layer_channel_exponent + self.depth_enc), style_dim)
|
|
|
|
self.dec = []
|
|
for i in range(self.depth_enc - 1, self.depth_dec - self.depth_enc -1, -1):
|
|
in_ch = 2**(self.first_layer_channel_exponent + i + 1)
|
|
out_ch = 2**(self.first_layer_channel_exponent + i)
|
|
skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0
|
|
self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
|
|
|
|
self.dec = nn.ModuleList(self.dec)
|
|
|
|
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent+shrink_factor_exponent), out_channels, kernel_size=1)
|
|
|
|
|
|
def forward(self, x, style):
|
|
out = x
|
|
outlist = []
|
|
|
|
for i in range(self.depth_enc):
|
|
skip, out = self.enc[i](out, style)
|
|
outlist.append(skip)
|
|
|
|
out = self.bottleneck(out, style)
|
|
|
|
for i in range(self.depth_dec):
|
|
if i < self.depth_enc:
|
|
out = self.dec[i](out, outlist[self.depth_enc - 1 - i], style)
|
|
else:
|
|
out = self.dec[i](out, None, style)
|
|
|
|
return self.final(out)
|