ML_GravPotBCs/sCOCA_ML/models/UNet_models.py

216 lines
8.9 KiB
Python

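"""
3D U-Net models with optional FiLM layers for style conditioning.

This module implements a 3D U-Net architecture (``UNet3D``) built from downsampling
(encoder) and upsampling (decoder) blocks joined by skip connections, plus a variant
(``UNet3D_Shrink``) whose output has a reduced spatial resolution.
"""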
import torch
import torch.nn as nn
from .base_class_models import BaseModel
from .FiLM import FiLM
class UNetBlock(nn.Module):
    """Two 3x3x3 convolutions, each followed by normalization and ReLU, with optional FiLM.

    Spatial size is preserved (``padding=1``); only the channel count changes from
    ``in_channels`` to ``out_channels``.

    Parameters:
    - in_channels: Number of channels of the input tensor.
    - out_channels: Number of channels produced by both convolutions.
    - style_dim: Dimension of the style vector. If falsy (None or 0), no FiLM layer is
      created and the ``style`` argument of ``forward`` is ignored.
    - batch_norm: If True, apply ``BatchNorm3d`` after each convolution; otherwise the
      normalization steps are identities.
    """

    def __init__(self, in_channels, out_channels, style_dim=None, batch_norm=True):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        if batch_norm:
            self.norm1 = nn.BatchNorm3d(out_channels)
            self.norm2 = nn.BatchNorm3d(out_channels)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
        self.relu = nn.ReLU(inplace=True)
        self.film = FiLM(out_channels, style_dim) if style_dim else None

    def forward(self, x, style=None):
        """Apply conv -> norm -> ReLU twice, then FiLM conditioning on ``style`` if enabled."""
        x = self.relu(self.norm1(self.conv1(x)))
        x = self.relu(self.norm2(self.conv2(x)))
        # Explicit None check: truthiness of an nn.Module is not a reliable "exists" test
        # (container modules such as nn.Sequential define __len__).
        if self.film is not None:
            x = self.film(x, style)
        return x
class UNetEncLayer(nn.Module):
    """Encoder stage of the U-Net: a ``UNetBlock`` followed by 2x max-pooling.

    ``forward`` returns both the full-resolution features (used as the skip connection)
    and their pooled version (fed to the next, coarser stage).
    """

    def __init__(self, in_channels, out_channels, style_dim=None):
        super().__init__()
        self.block = UNetBlock(in_channels, out_channels, style_dim)
        self.pool = nn.MaxPool3d(2)

    def forward(self, x, style=None):
        """Return ``(features, pooled_features)`` for input ``x`` and optional ``style``."""
        features = self.block(x, style)
        pooled_features = self.pool(features)
        return features, pooled_features
class UNetDecLayer(nn.Module):
    """Decoder stage of the U-Net: 2x transposed-convolution upsampling, optional skip
    concatenation along the channel axis, then a ``UNetBlock``.

    Parameters:
    - in_channels: Channels of the coarse input tensor.
    - out_channels: Channels after upsampling and after the ``UNetBlock``.
    - skip_connection_channels: Channels of the skip tensor concatenated after upsampling
      (0 when no skip connection will be passed).
    - style_dim: Dimension of the style vector for the ``UNetBlock`` (falsy disables FiLM).
    """

    def __init__(self, in_channels, out_channels, skip_connection_channels, style_dim=None):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        merged_channels = out_channels + skip_connection_channels
        self.block = UNetBlock(merged_channels, out_channels, style_dim)

    def forward(self, x, skip_connection, style=None):
        """Upsample ``x``, concatenate ``skip_connection`` on channels if given, then refine."""
        upsampled = self.up(x)
        # The skip tensor must match the upsampled tensor in every dimension but channels.
        merged = (
            upsampled
            if skip_connection is None
            else torch.cat((upsampled, skip_connection), dim=1)
        )
        return self.block(merged, style)
class UNet3D(BaseModel):
    def __init__(self, N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
                 depth: int = None,
                 device: torch.device = torch.device('cpu'),
                 first_layer_channel_exponent: int = 3,
                 ):
        """
        3D U-Net model with optional FiLM layers for style conditioning.

        Parameters:
        - N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_dim: Dimension of the style vector (default is 2). A falsy value (None or 0) disables FiLM.
        - depth: Number of encoder/decoder levels (default is None, which uses the maximum,
          floor(log2(N)) - 1). A depth of 0 reduces the network to the bottleneck block plus the output head.
        - device: Device to load the model onto (default is CPU).
        - first_layer_channel_exponent: The first level has 2**first_layer_channel_exponent channels;
          every deeper level doubles it (default is 3).

        Raises:
        - ValueError: if depth is negative or larger than floor(log2(N)) - 1.

        This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
        The model uses convolutional layers with ReLU activations and batch normalization.
        The FiLM layers are used to condition the feature maps on style parameters.
        The input's spatial size should be divisible by 2**depth: pooling floors odd sizes,
        after which the upsampled tensors no longer match the skip connections.
        """
        super().__init__(N=N,
                         in_channels=in_channels,
                         out_channels=out_channels,
                         style_parameters=style_dim,
                         device=device)
        import numpy as np

        # Deepest admissible depth: keeps the bottleneck at least 2 voxels wide.
        max_depth = int(np.floor(np.log2(N))) - 1
        if depth is None:
            depth = max_depth
        elif max_depth < depth:
            raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {max_depth}.")
        elif depth < 0:
            raise ValueError(f"Depth {depth} must be non-negative.")
        # Plain int (not np.int64) so the attribute behaves like any other Python integer.
        self.depth = int(depth)
        self.first_layer_channel_exponent = first_layer_channel_exponent

        def level_channels(level):
            # Channel count at a given U-Net level: doubles with every level of depth.
            return 2 ** (first_layer_channel_exponent + level)

        self.enc = nn.ModuleList(
            UNetEncLayer(in_channels if level == 0 else level_channels(level - 1),
                         level_channels(level), style_dim)
            for level in range(self.depth)
        )
        # With depth 0 there is no encoder, so the bottleneck reads the raw input channels.
        bottleneck_in = in_channels if self.depth == 0 else level_channels(self.depth - 1)
        self.bottleneck = UNetBlock(bottleneck_in, level_channels(self.depth), style_dim)
        # Decoder layers are ordered coarse-to-fine (level depth-1 down to 0); each upsamples
        # from level + 1 and concatenates the encoder features of the same level.
        self.dec = nn.ModuleList(
            UNetDecLayer(level_channels(level + 1), level_channels(level),
                         level_channels(level), style_dim)
            for level in reversed(range(self.depth))
        )
        width = level_channels(0)
        self.final = nn.Sequential(
            nn.Conv3d(width, width, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(width, out_channels, kernel_size=1),
        )

    def forward(self, x, style):
        """Run encoder, bottleneck and decoder on ``x`` conditioned on ``style``; apply the output head."""
        out = x
        skips = []
        for enc_layer in self.enc:
            skip, out = enc_layer(out, style)
            skips.append(skip)
        out = self.bottleneck(out, style)
        # Decoder runs coarse-to-fine, so it consumes the encoder features in reverse order.
        for dec_layer, skip in zip(self.dec, reversed(skips)):
            out = dec_layer(out, skip, style)
        return self.final(out)
class UNet3D_Shrink(BaseModel):
    def __init__(self,
                 N: int = 128,
                 in_channels: int = 2,
                 out_channels: int = 1,
                 style_dim: int = 2,
                 depth: int = None,
                 device: torch.device = torch.device('cpu'),
                 first_layer_channel_exponent: int = 3,
                 shrink_factor_exponent: int = 1,
                 ):
        """
        A 3D U-Net model with optional FiLM layers for style conditioning and a shrink factor.

        The output data is of size (N/shrink_factor, N/shrink_factor, N/shrink_factor), where N is
        the input size and shrink_factor = 2**shrink_factor_exponent. The encoder descends ``depth``
        levels, while the decoder climbs back only ``depth - shrink_factor_exponent`` of them and
        stops at level ``shrink_factor_exponent``. A negative exponent appends extra upsampling
        stages, without skip connections, beyond the input resolution.

        Parameters:
        - N: Size of the input data, of shape (B, C, N, N, N).
        - in_channels: Number of input channels (default is 2).
        - out_channels: Number of output channels (default is 1).
        - style_dim: Dimension of the style vector (default is 2). A falsy value disables FiLM.
        - depth: Number of encoder levels (default is None, which uses floor(log2(N)) - 1).
        - device: Device to load the model onto (default is CPU).
        - first_layer_channel_exponent: The first level has 2**first_layer_channel_exponent channels;
          every deeper level doubles it (default is 3).
        - shrink_factor_exponent: Base-2 logarithm of the input-to-output size ratio (default is 1).

        Raises:
        - ValueError: if depth is negative or larger than floor(log2(N)) - 1, if
          shrink_factor_exponent exceeds the depth, or if
          first_layer_channel_exponent + shrink_factor_exponent is negative.

        NOTE: earlier versions of this constructor built 2 * shrink_factor_exponent extra decoder
        layers that forward() never used. Checkpoints saved from those versions contain extra
        ``dec.<k>.*`` entries and must be loaded with ``load_state_dict(..., strict=False)``.
        """
        super().__init__(N=N,
                         in_channels=in_channels,
                         out_channels=out_channels,
                         style_parameters=style_dim,
                         device=device)
        import numpy as np

        # Deepest admissible depth: keeps the bottleneck at least 2 voxels wide.
        max_depth = int(np.floor(np.log2(N))) - 1
        if depth is None:
            depth = max_depth
        elif max_depth < depth:
            raise ValueError(f"Depth {depth} is too large for input size {N}. Maximum depth is {max_depth}.")
        elif depth < 0:
            raise ValueError(f"Depth {depth} must be non-negative.")
        self.depth_enc = int(depth)
        # Number of decoder (upsampling) stages needed to reach size N / 2**shrink_factor_exponent.
        self.depth_dec = self.depth_enc - shrink_factor_exponent
        if self.depth_dec < 0:
            raise ValueError(
                f"shrink_factor_exponent {shrink_factor_exponent} is too large for depth "
                f"{self.depth_enc}; it must not exceed the depth.")
        if first_layer_channel_exponent + shrink_factor_exponent < 0:
            raise ValueError(
                f"first_layer_channel_exponent + shrink_factor_exponent must be non-negative, "
                f"got {first_layer_channel_exponent} + {shrink_factor_exponent}.")
        self.first_layer_channel_exponent = first_layer_channel_exponent
        self.shrink_factor_exponent = shrink_factor_exponent

        def level_channels(level):
            # Channel count at a given U-Net level: doubles with every level of depth.
            return 2 ** (first_layer_channel_exponent + level)

        self.enc = nn.ModuleList(
            UNetEncLayer(in_channels if level == 0 else level_channels(level - 1),
                         level_channels(level), style_dim)
            for level in range(self.depth_enc)
        )
        # With depth 0 there is no encoder, so the bottleneck reads the raw input channels.
        bottleneck_in = in_channels if self.depth_enc == 0 else level_channels(self.depth_enc - 1)
        self.bottleneck = UNetBlock(bottleneck_in, level_channels(self.depth_enc), style_dim)
        # Exactly depth_dec decoder layers: levels depth_enc - 1 down to shrink_factor_exponent.
        # Levels >= 0 mirror an encoder level and concatenate its features; negative levels
        # (only reachable with a negative shrink exponent) have no skip connection.
        self.dec = nn.ModuleList(
            UNetDecLayer(level_channels(level + 1), level_channels(level),
                         level_channels(level) if level >= 0 else 0, style_dim)
            for level in range(self.depth_enc - 1, shrink_factor_exponent - 1, -1)
        )
        # The last decoder stage (or the bottleneck when depth_dec == 0) sits at level
        # shrink_factor_exponent, hence this channel count.
        self.final = nn.Conv3d(level_channels(shrink_factor_exponent), out_channels, kernel_size=1)

    def forward(self, x, style):
        """Encode ``x`` fully, decode back to size N / 2**shrink_factor_exponent, apply the 1x1 head."""
        out = x
        skips = []
        for enc_layer in self.enc:
            skip, out = enc_layer(out, style)
            skips.append(skip)
        out = self.bottleneck(out, style)
        for i, dec_layer in enumerate(self.dec):
            # Decoder stage i mirrors encoder level depth_enc - 1 - i while that level exists.
            skip = skips[self.depth_enc - 1 - i] if i < self.depth_enc else None
            out = dec_layer(out, skip, style)
        return self.final(out)