UNet shrinking
parent 37733e82e0
commit f0b828dc4d

1 changed file with 75 additions and 13 deletions
@@ -1,18 +1,12 @@
"""
|
||||
3D U-Net model with optional FiLM layers for style conditioning.
|
||||
This model implements a 3D U-Net architecture with downsampling and upsampling blocks, and skip connections.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .base_class_models import BaseModel
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, num_features, style_dim):
|
||||
super(FiLM, self).__init__()
|
||||
self.gamma = nn.Linear(style_dim, num_features)
|
||||
self.beta = nn.Linear(style_dim, num_features)
|
||||
|
||||
def forward(self, x, style):
|
||||
gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
return gamma * x + beta
|
||||
from .FiLM import FiLM
|
||||
|
||||
class UNetBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, style_dim=None):
|
||||
|
@@ -48,7 +42,8 @@ class UNetDecLayer(nn.Module):
 
     def forward(self, x, skip_connection, style=None):
         x = self.up(x)
-        x = torch.cat([x, skip_connection], dim=1)
+        if skip_connection is not None:
+            x = torch.cat([x, skip_connection], dim=1)
         return self.block(x, style)
 
 class UNet3D(BaseModel):
@@ -125,3 +120,70 @@ class UNet3D(BaseModel):
 
         return self.final(out)
 
+class UNet3D_Shrink(BaseModel):
+
+    def __init__(self,
+                 N: int = 128,
+                 in_channels: int = 2,
+                 out_channels: int = 1,
+                 style_dim: int = 2,
+                 device: torch.device = torch.device('cpu'),
+                 first_layer_channel_exponent: int = 3,
+                 shrink_factor_exponent: int = 1,
+                 ):
"""
|
||||
A 3D U-Net model with optional FiLM layers for style conditioning and a shrink factor.
|
||||
It means that the output data is of size (N/shrink_factor, N/shrink_factor, N/shrink_factor), where N is the input size.
|
||||
"""
+        super().__init__(N=N,
+                         in_channels=in_channels,
+                         out_channels=out_channels,
+                         style_parameters=style_dim,
+                         device=device)
+        import numpy as np
+
+        self.depth_enc = np.floor(np.log2(N)).astype(int) - 1  # Encoder depth, derived from the input size N
+        self.depth_dec = self.depth_enc - shrink_factor_exponent  # Decoder depth; fewer upsampling stages give the shrunken output
+        self.first_layer_channel_exponent = first_layer_channel_exponent
+        self.shrink_factor_exponent = shrink_factor_exponent
+
+        # Encoder: the channel width doubles at every downsampling stage.
+        self.enc = []
+        for i in range(self.depth_enc):
+            in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
+
+        self.enc = nn.ModuleList(self.enc)
+
+        self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth_enc - 1),
+                                    2**(self.first_layer_channel_exponent + self.depth_enc), style_dim)
+
+        # Decoder: only depth_dec (= depth_enc - shrink_factor_exponent) upsampling stages.
+        self.dec = []
+        for i in range(self.depth_dec - 1, -1, -1):
+            in_ch = 2**(self.first_layer_channel_exponent + i + 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            skip_conn_ch = out_ch if i >= self.depth_dec - self.depth_enc else 0
+            self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
+
+        self.dec = nn.ModuleList(self.dec)
+
+        self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
+
+    def forward(self, x, style):
+        out = x
+        outlist = []
+
+        for i in range(self.depth_enc):
+            skip, out = self.enc[i](out, style)
+            outlist.append(skip)
+
+        out = self.bottleneck(out, style)
+
+        for i in range(self.depth_dec):
+            if i < self.depth_enc:
+                out = self.dec[i](out, outlist[self.depth_enc - 1 - i], style)
+            else:
+                out = self.dec[i](out, None, style)
+
+        return self.final(out)
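As a quick sanity check on the new class, the depth and shrink arithmetic from UNet3D_Shrink.__init__ can be reproduced on its own. The snippet below is a minimal sketch, not part of the commit; it uses the default constructor values and assumes that each UNetEncLayer halves and each UNetDecLayer doubles the spatial resolution (those layer definitions are not shown in this diff).

import numpy as np

# Worked example of the depth / shrink arithmetic used in UNet3D_Shrink.__init__
# (default constructor values; this sketch does not instantiate the network).
N = 128
first_layer_channel_exponent = 3
shrink_factor_exponent = 1

depth_enc = int(np.floor(np.log2(N))) - 1        # 6 downsampling stages in the encoder
depth_dec = depth_enc - shrink_factor_exponent   # 5 upsampling stages in the decoder
shrink_factor = 2 ** shrink_factor_exponent      # 2

print(depth_enc, depth_dec)                              # 6 5
print(2 ** (first_layer_channel_exponent + depth_enc))   # 512 channels at the bottleneck
print(N // shrink_factor)                                # 64 -> expected output volumes of 64^3

With these defaults, a forward pass on an input of shape (batch, in_channels, 128, 128, 128) together with a style vector of shape (batch, style_dim) is therefore expected to return (batch, out_channels, 64, 64, 64), provided the encoder and decoder stages rescale by a factor of two each as assumed above.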