import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_class_models import BaseModel


class FiLM(nn.Module):
    """Feature-wise linear modulation: scale and shift features from a style vector.

    Two linear layers map the style vector to one scale (``gamma``) and one
    shift (``beta``) per feature channel.
    """

    def __init__(self, num_features: int, style_dim: int):
        """
        Parameters:
        - num_features: Number of channels of the feature map to modulate.
        - style_dim: Size of the last dimension of the style vector.
        """
        super().__init__()
        self.gamma = nn.Linear(style_dim, num_features)
        self.beta = nn.Linear(style_dim, num_features)

    def forward(self, x: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
        """Return ``gamma(style) * x + beta(style)``.

        Three trailing singleton dimensions are appended to the projected
        style so that it broadcasts over the last three dimensions of ``x``.
        Assumes ``style`` is (B, style_dim) and ``x`` is (B, C, D, H, W) --
        TODO confirm against callers.
        """
        scale = self.gamma(style)[..., None, None, None]
        shift = self.beta(style)[..., None, None, None]
        return scale * x + shift


class UNetBlock(nn.Module):
    """Two 3x3x3 convolutions, each followed by batch norm and ReLU, plus optional FiLM."""

    def __init__(self, in_channels: int, out_channels: int,
                 style_dim: Optional[int] = None):
        """
        Parameters:
        - in_channels: Number of channels of the input.
        - out_channels: Number of channels produced by both convolutions.
        - style_dim: Dimension of the style vector; when None or 0 no FiLM
          layer is created and the style argument of ``forward`` is ignored.
        """
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        # NOTE(review): this single BatchNorm3d is applied after BOTH
        # convolutions in forward(), so the two stages share affine parameters
        # and running statistics. Kept as-is because splitting it would change
        # the state_dict keys of existing checkpoints.
        self.norm = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.film = FiLM(out_channels, style_dim) if style_dim else None

    def forward(self, x: torch.Tensor,
                style: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply conv -> norm -> ReLU twice, then FiLM modulation if configured."""
        for conv in (self.conv1, self.conv2):
            x = self.relu(self.norm(conv(x)))
        # Explicit None check: the FiLM layer is either a module or absent.
        if self.film is not None:
            x = self.film(x, style)
        return x


class UNetEncLayer(nn.Module):
    """Encoder stage: a UNetBlock followed by a max-pooling with kernel size 2."""

    def __init__(self, in_channels: int, out_channels: int,
                 style_dim: Optional[int] = None):
        """
        Parameters:
        - in_channels: Number of channels of the input.
        - out_channels: Number of channels produced by the block.
        - style_dim: Dimension of the style vector forwarded to the block.
        """
        super().__init__()
        self.block = UNetBlock(in_channels, out_channels, style_dim)
        self.pool = nn.MaxPool3d(2)

    def forward(self, x: torch.Tensor, style: Optional[torch.Tensor] = None):
        """Return ``(features, pooled)``.

        ``features`` is the block output before pooling (used as the skip
        connection); ``pooled`` is the same tensor after max-pooling.
        """
        features = self.block(x, style)
        return features, self.pool(features)


class UNetDecLayer(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, UNetBlock."""

    def __init__(self, in_channels: int, out_channels: int,
                 skip_connection_channels: int,
                 style_dim: Optional[int] = None):
        """
        Parameters:
        - in_channels: Number of channels of the tensor to upsample.
        - out_channels: Number of channels after upsampling and after the block.
        - skip_connection_channels: Number of channels of the skip tensor that
          is concatenated with the upsampled tensor.
        - style_dim: Dimension of the style vector forwarded to the block.
        """
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.block = UNetBlock(out_channels + skip_connection_channels,
                               out_channels, style_dim)

    def forward(self, x: torch.Tensor, skip_connection: torch.Tensor,
                style: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Upsample ``x``, concatenate the skip tensor on the channel axis, run the block."""
        upsampled = self.up(x)
        merged = torch.cat([upsampled, skip_connection], dim=1)
        return self.block(merged, style)


class UNet3D(BaseModel):
    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
                 device: torch.device = torch.device('cpu')):
        """
        3D U-Net model with optional FiLM layers for style conditioning.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N)
          with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_dim: Dimension of the style vector (default is 2).
        - device: Device forwarded to the base class (default is CPU).

        Raises:
        - ValueError: if N is not strictly positive.

        The network has ``floor(log2(N)) - 1`` encoder stages, a bottleneck
        block and the same number of decoder stages. Encoder stage ``i``
        outputs ``2**(3 + i)`` channels and the bottleneck doubles the channel
        count of the last encoder stage. Each block uses convolutions with
        ReLU activations and batch normalization; FiLM layers condition the
        feature maps on the style parameters. A final 1x1x1 convolution maps
        the features to ``out_channels``.
        """
        super().__init__(N=N, in_channels=in_channels, out_channels=out_channels,
                         style_parameters=style_dim, device=device)

        if N <= 0:
            raise ValueError(f"N must be a positive size, got {N!r}")

        # Depth of the U-Net based on input size N. math (rather than numpy)
        # keeps self.depth and every derived channel count a plain Python int.
        self.depth = math.floor(math.log2(N)) - 1
        self.first_layer_channel_exponent = 3
        base_exp = self.first_layer_channel_exponent

        # Encoder: the first stage consumes the raw input channels, every
        # later stage consumes the output of the previous one.
        self.enc = nn.ModuleList([
            UNetEncLayer(
                in_channels if i == 0 else 2 ** (base_exp + i - 1),
                2 ** (base_exp + i),
                style_dim,
            )
            for i in range(self.depth)
        ])

        self.bottleneck = UNetBlock(2 ** (base_exp + self.depth - 1),
                                    2 ** (base_exp + self.depth),
                                    style_dim)

        # Decoder: built from the deepest level up to the shallowest; the skip
        # tensor at each level has as many channels as the stage output.
        self.dec = nn.ModuleList([
            UNetDecLayer(
                2 ** (base_exp + i + 1),
                2 ** (base_exp + i),
                2 ** (base_exp + i),
                style_dim,
            )
            for i in range(self.depth - 1, -1, -1)
        ])

        self.final = nn.Conv3d(2 ** base_exp, out_channels, kernel_size=1)

    def forward(self, x, style):
        """Run the encoder, bottleneck and decoder, then the final 1x1x1 convolution.

        Parameters:
        - x: Input tensor, documented on the constructor as (B, C, N, N, N).
        - style: Style tensor passed to every block for FiLM conditioning.

        Returns the output of the final convolution.
        """
        out = x
        skips = []
        for enc_layer in self.enc:
            skip, out = enc_layer(out, style)
            skips.append(skip)

        out = self.bottleneck(out, style)

        # Decoder stages consume the skip tensors in reverse (deepest first).
        for dec_layer, skip in zip(self.dec, reversed(skips)):
            out = dec_layer(out, skip, style)

        return self.final(out)