better UNets with different batch norms and custom depth

Mayeul Aubin 2025-06-30 14:30:10 +02:00
parent c7cf9fe7d4
commit 0cbd7fcc46


@@ -9,17 +9,24 @@ from .base_class_models import BaseModel
 from .FiLM import FiLM

 class UNetBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, style_dim=None):
+    def __init__(self, in_channels, out_channels, style_dim=None, batch_norm=True):
         super(UNetBlock, self).__init__()
         self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
         self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
-        self.norm = nn.BatchNorm3d(out_channels)
+        if batch_norm:
+            self.norm1 = nn.BatchNorm3d(out_channels)
+            self.norm2 = nn.BatchNorm3d(out_channels)
+        else:
+            self.norm1 = nn.Identity()
+            self.norm2 = nn.Identity()
         self.relu = nn.ReLU(inplace=True)
         self.film = FiLM(out_channels, style_dim) if style_dim else None

     def forward(self, x, style=None):
-        x = self.relu(self.norm(self.conv1(x)))
-        x = self.relu(self.norm(self.conv2(x)))
+        x = self.relu(self.norm1(self.conv1(x)))
+        x = self.relu(self.norm2(self.conv2(x)))
         if self.film:
             x = self.film(x, style)
         return x
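The block now keeps one normalization layer per convolution and swaps in nn.Identity() when batch norm is disabled, so forward() is identical in both modes. A minimal standalone sketch of that pattern (FiLM conditioning omitted; channel counts and the input shape are only illustrative):

import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, batch_norm=True):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        # One norm per conv; nn.Identity() is a no-op stand-in when batch norm is off.
        self.norm1 = nn.BatchNorm3d(out_channels) if batch_norm else nn.Identity()
        self.norm2 = nn.BatchNorm3d(out_channels) if batch_norm else nn.Identity()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.norm1(self.conv1(x)))
        return self.relu(self.norm2(self.conv2(x)))

x = torch.randn(2, 2, 16, 16, 16)              # (batch, channels, D, H, W)
print(Block(2, 8, batch_norm=False)(x).shape)  # torch.Size([2, 8, 16, 16, 16])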
@@ -51,6 +58,7 @@ class UNet3D(BaseModel):
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
+                depth: int = None,
                 device: torch.device = torch.device('cpu'),
                 first_layer_channel_exponent: int = 3,
                 ):
@@ -61,7 +69,9 @@ class UNet3D(BaseModel):
         - in_channels: Number of input channels (default is 2).
         - out_channels: Number of output channels (default is 1).
         - style_dim: Dimension of the style vector (default is 2).
+        - depth: Depth of the U-Net (default is None, which will be computed based on N).
         - device: Device to load the model onto (default is CPU).
+        - first_layer_channel_exponent: Exponent for the number of channels in the first layer (default is 3).
         This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
         The model uses convolutional layers with ReLU activations and batch normalization.
         The FiLM layers are used to condition the feature maps on style parameters.
@@ -74,7 +84,12 @@ class UNet3D(BaseModel):
                          device=device)

         import numpy as np
-        self.depth = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
+        if depth is not None:
+            self.depth = depth
+            if np.floor(np.log2(N)).astype(int) - 1 < depth:
+                raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
+        else:
+            self.depth = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
         self.first_layer_channel_exponent = first_layer_channel_exponent

         self.enc=[]
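When depth is left at None the model keeps the old behaviour, floor(log2(N)) - 1 levels, and an explicit depth above that limit raises the ValueError above. A quick check of the default values (a sketch; N stands for the input grid size taken by the constructor):

import numpy as np

for N in (32, 64, 128, 256):
    print(N, np.floor(np.log2(N)).astype(int) - 1)   # 32 -> 4, 64 -> 5, 128 -> 6, 256 -> 7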
@@ -100,7 +115,12 @@ class UNet3D(BaseModel):
         self.dec = nn.ModuleList(self.dec)

-        self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
+        # self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=3, padding=1)
+        self.final = nn.Sequential(
+            nn.Conv3d(2**(self.first_layer_channel_exponent), 2**(self.first_layer_channel_exponent), kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
+        )
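The output head is no longer a single 1x1 convolution: a 3x3 convolution and ReLU now precede the final 1x1 projection, so the last prediction sees a local neighbourhood instead of a single voxel. A minimal sketch of that head, assuming the default first_layer_channel_exponent = 3 (8 channels) and out_channels = 1:

import torch
import torch.nn as nn

c = 2 ** 3                                        # first_layer_channel_exponent = 3
head = nn.Sequential(
    nn.Conv3d(c, c, kernel_size=3, padding=1),    # 3x3x3 mixing, spatial size preserved
    nn.ReLU(inplace=True),
    nn.Conv3d(c, 1, kernel_size=1),               # 1x1x1 projection to out_channels
)
print(head(torch.randn(1, c, 16, 16, 16)).shape)  # torch.Size([1, 1, 16, 16, 16])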
@@ -127,6 +147,7 @@ class UNet3D_Shrink(BaseModel):
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
+                depth: int = None,
                 device: torch.device = torch.device('cpu'),
                 first_layer_channel_exponent: int = 3,
                 shrink_factor_exponent: int = 1,
@@ -142,7 +163,13 @@ class UNet3D_Shrink(BaseModel):
                          device=device)

         import numpy as np
-        self.depth_enc = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
+        if depth is not None:
+            self.depth_enc = depth
+            if np.floor(np.log2(N)).astype(int) - 1 < depth:
+                raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
+        else:
+            self.depth_enc = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
         self.depth_dec = self.depth_enc - shrink_factor_exponent  # Depth of the U-Net based on input size N and shrink factor
         self.first_layer_channel_exponent = first_layer_channel_exponent
         self.shrink_factor_exponent = shrink_factor_exponent
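UNet3D_Shrink applies the same depth override to its encoder depth, then shortens the decoder by shrink_factor_exponent. A sketch of that bookkeeping with example values (N = 64, depth = 4 and shrink_factor_exponent = 1 are chosen only for illustration):

import numpy as np

N, depth, shrink_factor_exponent = 64, 4, 1
max_depth = np.floor(np.log2(N)).astype(int) - 1   # 5 for N = 64
assert depth <= max_depth, f"Depth {depth} is too large for input size {N}."
depth_enc = depth                                  # encoder levels
depth_dec = depth_enc - shrink_factor_exponent     # decoder levels
print(depth_enc, depth_dec)                        # 4 3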