better UNets with different batch norms and custom depth
This commit is contained in:
parent
c7cf9fe7d4
commit
0cbd7fcc46
1 changed files with 34 additions and 7 deletions
|
@ -9,17 +9,24 @@ from .base_class_models import BaseModel
|
||||||
from .FiLM import FiLM
|
from .FiLM import FiLM
|
||||||
|
|
||||||
class UNetBlock(nn.Module):
    """Two stacked 3D conv stages (conv -> norm -> ReLU), optionally followed by FiLM.

    Both convolutions use ``kernel_size=3`` and ``padding=1`` with the default
    stride, so the spatial size of the input is preserved; only the channel
    count changes from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels fed to the first convolution.
        out_channels: Number of channels produced by both convolutions.
        style_dim: Forwarded to ``FiLM(out_channels, style_dim)``. When falsy
            (``None`` or ``0``) no FiLM layer is built and ``style`` is ignored
            in ``forward``.
        batch_norm: If True, each convolution is followed by its own
            ``nn.BatchNorm3d``; if False, ``nn.Identity`` is used instead.
    """

    def __init__(self, in_channels, out_channels, style_dim=None, batch_norm=True):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        # One normalisation layer per convolution: a single shared BatchNorm3d
        # would mix the running statistics and affine parameters of two
        # different activation distributions.
        if batch_norm:
            self.norm1 = nn.BatchNorm3d(out_channels)
            self.norm2 = nn.BatchNorm3d(out_channels)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

        self.relu = nn.ReLU(inplace=True)
        self.film = FiLM(out_channels, style_dim) if style_dim else None

    def forward(self, x, style=None):
        """Apply both conv stages, then FiLM conditioning if it was configured.

        Args:
            x: Input accepted by ``nn.Conv3d`` with ``in_channels`` channels.
            style: Passed unchanged to the FiLM layer as ``self.film(x, style)``;
                unused when the block was built without ``style_dim``.

        Returns:
            The transformed tensor with ``out_channels`` channels.
        """
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        if self.film is not None:
            x = self.film(x, style)
        return x
|
||||||
|
@ -51,6 +58,7 @@ class UNet3D(BaseModel):
|
||||||
in_channels: int = 2,
|
in_channels: int = 2,
|
||||||
out_channels: int = 1,
|
out_channels: int = 1,
|
||||||
style_dim: int = 2,
|
style_dim: int = 2,
|
||||||
|
depth: int = None,
|
||||||
device: torch.device = torch.device('cpu'),
|
device: torch.device = torch.device('cpu'),
|
||||||
first_layer_channel_exponent: int = 3,
|
first_layer_channel_exponent: int = 3,
|
||||||
):
|
):
|
||||||
|
@ -61,7 +69,9 @@ class UNet3D(BaseModel):
|
||||||
- in_channels: Number of input channels (default is 2).
|
- in_channels: Number of input channels (default is 2).
|
||||||
- out_channels: Number of output channels (default is 1).
|
- out_channels: Number of output channels (default is 1).
|
||||||
- style_dim: Dimension of the style vector (default is 2).
|
- style_dim: Dimension of the style vector (default is 2).
|
||||||
|
- depth: Depth of the U-Net (default is None, in which case the depth is computed from the input size N).
|
||||||
- device: Device to load the model onto (default is CPU).
|
- device: Device to load the model onto (default is CPU).
|
||||||
|
- first_layer_channel_exponent: Exponent for the number of channels in the first layer (default is 3).
|
||||||
This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
|
This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
|
||||||
The model uses convolutional layers with ReLU activations and batch normalization.
|
The model uses convolutional layers with ReLU activations and batch normalization.
|
||||||
The FiLM layers are used to condition the feature maps on style parameters.
|
The FiLM layers are used to condition the feature maps on style parameters.
|
||||||
|
@ -74,7 +84,12 @@ class UNet3D(BaseModel):
|
||||||
device=device)
|
device=device)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
if depth is not None:
|
||||||
|
self.depth = depth
|
||||||
|
if np.floor(np.log2(N)).astype(int) - 1 < depth:
|
||||||
|
raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
|
||||||
|
else:
|
||||||
|
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
||||||
self.first_layer_channel_exponent = first_layer_channel_exponent
|
self.first_layer_channel_exponent = first_layer_channel_exponent
|
||||||
|
|
||||||
self.enc=[]
|
self.enc=[]
|
||||||
|
@ -100,7 +115,12 @@ class UNet3D(BaseModel):
|
||||||
self.dec = nn.ModuleList(self.dec)
|
self.dec = nn.ModuleList(self.dec)
|
||||||
|
|
||||||
|
|
||||||
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
|
# self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=3, padding=1)
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
nn.Conv3d(2**(self.first_layer_channel_exponent), 2**(self.first_layer_channel_exponent), kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,6 +147,7 @@ class UNet3D_Shrink(BaseModel):
|
||||||
in_channels: int = 2,
|
in_channels: int = 2,
|
||||||
out_channels: int = 1,
|
out_channels: int = 1,
|
||||||
style_dim: int = 2,
|
style_dim: int = 2,
|
||||||
|
depth: int = None,
|
||||||
device: torch.device = torch.device('cpu'),
|
device: torch.device = torch.device('cpu'),
|
||||||
first_layer_channel_exponent: int = 3,
|
first_layer_channel_exponent: int = 3,
|
||||||
shrink_factor_exponent: int = 1,
|
shrink_factor_exponent: int = 1,
|
||||||
|
@ -142,7 +163,13 @@ class UNet3D_Shrink(BaseModel):
|
||||||
device=device)
|
device=device)
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
self.depth_enc = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
if depth is not None:
|
||||||
|
self.depth_enc = depth
|
||||||
|
if np.floor(np.log2(N)).astype(int) - 1 < depth:
|
||||||
|
raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {np.floor(np.log2(N)).astype(int) - 1}.")
|
||||||
|
else:
|
||||||
|
self.depth_enc = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
||||||
|
|
||||||
self.depth_dec = self.depth_enc - shrink_factor_exponent # Depth of the U-Net based on input size N and shrink factor
|
self.depth_dec = self.depth_enc - shrink_factor_exponent # Depth of the U-Net based on input size N and shrink factor
|
||||||
self.first_layer_channel_exponent = first_layer_channel_exponent
|
self.first_layer_channel_exponent = first_layer_channel_exponent
|
||||||
self.shrink_factor_exponent = shrink_factor_exponent
|
self.shrink_factor_exponent = shrink_factor_exponent
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue