ChatGPT models
This commit is contained in:
parent
2b9830211e
commit
26af105195
4 changed files with 193 additions and 0 deletions
76
sCOCA_ML/models/UNet_models.py
Normal file
76
sCOCA_ML/models/UNet_models.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .base_class_models import BaseModel
|
||||
|
||||
class FiLM(nn.Module):
|
||||
def __init__(self, num_features, style_dim):
|
||||
super(FiLM, self).__init__()
|
||||
self.gamma = nn.Linear(style_dim, num_features)
|
||||
self.beta = nn.Linear(style_dim, num_features)
|
||||
|
||||
def forward(self, x, style):
|
||||
gamma = self.gamma(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
beta = self.beta(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
||||
return gamma * x + beta
|
||||
|
||||
class UNetBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, style_dim=None):
|
||||
super(UNetBlock, self).__init__()
|
||||
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
|
||||
self.norm = nn.BatchNorm3d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.film = FiLM(out_channels, style_dim) if style_dim else None
|
||||
|
||||
def forward(self, x, style=None):
|
||||
x = self.relu(self.norm(self.conv1(x)))
|
||||
x = self.relu(self.norm(self.conv2(x)))
|
||||
if self.film:
|
||||
x = self.film(x, style)
|
||||
return x
|
||||
|
||||
class UNet3D(BaseModel):
|
||||
def __init__(self, N: int = 128,
|
||||
in_channels: int = 2,
|
||||
out_channels: int = 1,
|
||||
style_dim: int = 2,
|
||||
device: torch.device = torch.device('cpu')):
|
||||
"""
|
||||
3D U-Net model with optional FiLM layers for style conditioning.
|
||||
Parameters:
|
||||
- N: Size of the input data: data will have the shape (B, C, N, N, N) with B the batch size, C the number of channels.
|
||||
- in_channels: Number of input channels (default is 2).
|
||||
- out_channels: Number of output channels (default is 1).
|
||||
- style_dim: Dimension of the style vector (default is 2).
|
||||
- device: Device to load the model onto (default is CPU).
|
||||
This model implements a 3D U-Net architecture with downsampling and upsampling blocks.
|
||||
The model uses convolutional layers with ReLU activations and batch normalization.
|
||||
The FiLM layers are used to condition the feature maps on style parameters.
|
||||
"""
|
||||
|
||||
super().init(N=N,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
style_parameters=style_dim,
|
||||
device=device)
|
||||
|
||||
self.enc1 = UNetBlock(in_channels, 32, style_dim)
|
||||
self.pool1 = nn.MaxPool3d(2)
|
||||
self.enc2 = UNetBlock(32, 64, style_dim)
|
||||
self.pool2 = nn.MaxPool3d(2)
|
||||
self.bottleneck = UNetBlock(64, 128, style_dim)
|
||||
|
||||
self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
|
||||
self.dec2 = UNetBlock(128, 64)
|
||||
self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
|
||||
self.dec1 = UNetBlock(64, 32)
|
||||
self.final = nn.Conv3d(32, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x, style):
|
||||
e1 = self.enc1(x, style)
|
||||
e2 = self.enc2(self.pool1(e1), style)
|
||||
b = self.bottleneck(self.pool2(e2), style)
|
||||
d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
|
||||
d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
|
||||
return self.final(d1)
|
Loading…
Add table
Add a link
Reference in a new issue