diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py index fcd8eef..59a4398 100644 --- a/sCOCA_ML/models/UNet_models.py +++ b/sCOCA_ML/models/UNet_models.py @@ -1,18 +1,12 @@ +""" +3D U-Net model with optional FiLM layers for style conditioning. +This model implements a 3D U-Net architecture with downsampling and upsampling blocks, and skip connections. +""" + import torch import torch.nn as nn -import torch.nn.functional as F from .base_class_models import BaseModel - -class FiLM(nn.Module): - def __init__(self, num_features, style_dim): - super(FiLM, self).__init__() - self.gamma = nn.Linear(style_dim, num_features) - self.beta = nn.Linear(style_dim, num_features) - - def forward(self, x, style): - gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - return gamma * x + beta +from .FiLM import FiLM class UNetBlock(nn.Module): def __init__(self, in_channels, out_channels, style_dim=None): @@ -48,7 +42,8 @@ class UNetDecLayer(nn.Module): def forward(self, x, skip_connection, style=None): x = self.up(x) - x = torch.cat([x, skip_connection], dim=1) + if skip_connection is not None: + x = torch.cat([x, skip_connection], dim=1) return self.block(x, style) class UNet3D(BaseModel): @@ -125,3 +120,70 @@ class UNet3D(BaseModel): return self.final(out) +class UNet3D_Shrink(BaseModel): + + def __init__(self, + N: int = 128, + in_channels: int = 2, + out_channels: int = 1, + style_dim: int = 2, + device: torch.device = torch.device('cpu'), + first_layer_channel_exponent: int = 3, + shrink_factor_exponent: int = 1, + ): + """ + A 3D U-Net model with optional FiLM layers for style conditioning and a shrink factor. + It means that the output data is of size (N/shrink_factor, N/shrink_factor, N/shrink_factor), where N is the input size. + """ + super().__init__(N=N, + in_channels=in_channels, + out_channels=out_channels, + style_parameters=style_dim, + device=device) + import numpy as np + + self.depth_enc = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N + self.depth_dec = self.depth_enc - shrink_factor_exponent # Depth of the U-Net based on input size N and shrink factor + self.first_layer_channel_exponent = first_layer_channel_exponent + self.shrink_factor_exponent = shrink_factor_exponent + + self.enc = [] + for i in range(self.depth_enc): + in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1) + out_ch = 2**(self.first_layer_channel_exponent + i) + self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim)) + + self.enc = nn.ModuleList(self.enc) + + self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth_enc - 1), + 2**(self.first_layer_channel_exponent + self.depth_enc), style_dim) + + self.dec = [] + for i in range(self.depth_dec - 1, -1, -1): + in_ch = 2**(self.first_layer_channel_exponent + i + 1) + out_ch = 2**(self.first_layer_channel_exponent + i) + skip_conn_ch = out_ch if i >= self.depth_dec-self.depth_enc else 0 + self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim)) + + self.dec = nn.ModuleList(self.dec) + + self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1) + + + def forward(self, x, style): + out = x + outlist = [] + + for i in range(self.depth_enc): + skip, out = self.enc[i](out, style) + outlist.append(skip) + + out = self.bottleneck(out, style) + + for i in range(self.depth_dec): + if i < self.depth_enc: + out = self.dec[i](out, outlist[self.depth_enc - 1 - i], style) + else: + out = self.dec[i](out, None, style) + + return self.final(out)