many improvements
This commit is contained in:
parent
c07ec8f8cf
commit
6c526d7115
4 changed files with 219 additions and 53 deletions
|
@ -30,6 +30,27 @@ class UNetBlock(nn.Module):
|
|||
x = self.film(x, style)
|
||||
return x
|
||||
|
||||
class UNetEncLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, style_dim=None):
|
||||
super(UNetEncLayer, self).__init__()
|
||||
self.block = UNetBlock(in_channels, out_channels, style_dim)
|
||||
self.pool = nn.MaxPool3d(2)
|
||||
|
||||
def forward(self, x, style=None):
|
||||
x = self.block(x, style)
|
||||
return x, self.pool(x)
|
||||
|
||||
class UNetDecLayer(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, skip_connection_channels, style_dim=None):
|
||||
super(UNetDecLayer, self).__init__()
|
||||
self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
|
||||
self.block = UNetBlock(out_channels + skip_connection_channels, out_channels, style_dim)
|
||||
|
||||
def forward(self, x, skip_connection, style=None):
|
||||
x = self.up(x)
|
||||
x = torch.cat([x, skip_connection], dim=1)
|
||||
return self.block(x, style)
|
||||
|
||||
class UNet3D(BaseModel):
|
||||
def __init__(self, N: int = 128,
|
||||
in_channels: int = 2,
|
||||
|
@ -54,23 +75,51 @@ class UNet3D(BaseModel):
|
|||
out_channels=out_channels,
|
||||
style_parameters=style_dim,
|
||||
device=device)
|
||||
import numpy as np
|
||||
|
||||
self.enc1 = UNetBlock(in_channels, 32, style_dim)
|
||||
self.pool1 = nn.MaxPool3d(2)
|
||||
self.enc2 = UNetBlock(32, 64, style_dim)
|
||||
self.pool2 = nn.MaxPool3d(2)
|
||||
self.bottleneck = UNetBlock(64, 128, style_dim)
|
||||
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
||||
self.first_layer_channel_exponent = 3
|
||||
|
||||
self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
|
||||
self.dec2 = UNetBlock(128, 64)
|
||||
self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
|
||||
self.dec1 = UNetBlock(64, 32)
|
||||
self.final = nn.Conv3d(32, out_channels, kernel_size=1)
|
||||
self.enc=[]
|
||||
|
||||
for i in range(self.depth):
|
||||
in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
|
||||
out_ch = 2**(self.first_layer_channel_exponent + i)
|
||||
self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
|
||||
|
||||
self.enc = nn.ModuleList(self.enc)
|
||||
|
||||
self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth - 1),
|
||||
2**(self.first_layer_channel_exponent + self.depth), style_dim)
|
||||
|
||||
self.dec=[]
|
||||
|
||||
for i in range(self.depth - 1, -1, -1):
|
||||
in_ch = 2**(self.first_layer_channel_exponent + i + 1)
|
||||
out_ch = 2**(self.first_layer_channel_exponent + i)
|
||||
skip_conn_ch = out_ch
|
||||
self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
|
||||
|
||||
self.dec = nn.ModuleList(self.dec)
|
||||
|
||||
|
||||
self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
|
||||
|
||||
|
||||
|
||||
def forward(self, x, style):
|
||||
e1 = self.enc1(x, style)
|
||||
e2 = self.enc2(self.pool1(e1), style)
|
||||
b = self.bottleneck(self.pool2(e2), style)
|
||||
d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
|
||||
d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
|
||||
return self.final(d1)
|
||||
|
||||
out = x
|
||||
outlist = []
|
||||
|
||||
for i in range(self.depth):
|
||||
skip, out = self.enc[i](out, style)
|
||||
outlist.append(skip)
|
||||
|
||||
out = self.bottleneck(out, style)
|
||||
|
||||
for i in range(self.depth):
|
||||
out = self.dec[i](out, outlist[self.depth - 1 - i], style)
|
||||
|
||||
return self.final(out)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue