Change U-Net and V-Net inner naming convention; Slim down V-Net
This commit is contained in:
parent
01ff0aca37
commit
94ce018cb8
@ -8,37 +8,37 @@ class UNet(nn.Module):
|
|||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_0l = ConvBlock(in_channels, 64, seq='CAC')
|
self.conv_l0 = ConvBlock(in_channels, 64, seq='CAC')
|
||||||
self.down_0l = ConvBlock(64, seq='BADBA')
|
self.down_l0 = ConvBlock(64, seq='BADBA')
|
||||||
self.conv_1l = ConvBlock(64, seq='CBAC')
|
self.conv_l1 = ConvBlock(64, seq='CBAC')
|
||||||
self.down_1l = ConvBlock(64, seq='BADBA')
|
self.down_l1 = ConvBlock(64, seq='BADBA')
|
||||||
|
|
||||||
self.conv_2c = ConvBlock(64, seq='CBAC')
|
self.conv_c = ConvBlock(64, seq='CBAC')
|
||||||
|
|
||||||
self.up_1r = ConvBlock(64, seq='BAUBA')
|
self.up_r1 = ConvBlock(64, seq='BAUBA')
|
||||||
self.conv_1r = ConvBlock(128, 64, seq='CBAC')
|
self.conv_r1 = ConvBlock(128, 64, seq='CBAC')
|
||||||
self.up_0r = ConvBlock(64, seq='BAUBA')
|
self.up_r0 = ConvBlock(64, seq='BAUBA')
|
||||||
self.conv_0r = ConvBlock(128, out_channels, seq='CAC')
|
self.conv_r0 = ConvBlock(128, out_channels, seq='CAC')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y0 = self.conv_0l(x)
|
y0 = self.conv_l0(x)
|
||||||
x = self.down_0l(y0)
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
y1 = self.conv_1l(x)
|
y1 = self.conv_l1(x)
|
||||||
x = self.down_1l(y1)
|
x = self.down_l1(y1)
|
||||||
|
|
||||||
x = self.conv_2c(x)
|
x = self.conv_c(x)
|
||||||
|
|
||||||
x = self.up_1r(x)
|
x = self.up_r1(x)
|
||||||
y1 = narrow_like(y1, x)
|
y1 = narrow_like(y1, x)
|
||||||
x = torch.cat([y1, x], dim=1)
|
x = torch.cat([y1, x], dim=1)
|
||||||
del y1
|
del y1
|
||||||
x = self.conv_1r(x)
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
x = self.up_0r(x)
|
x = self.up_r0(x)
|
||||||
y0 = narrow_like(y0, x)
|
y0 = narrow_like(y0, x)
|
||||||
x = torch.cat([y0, x], dim=1)
|
x = torch.cat([y0, x], dim=1)
|
||||||
del y0
|
del y0
|
||||||
x = self.conv_0r(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -8,52 +8,37 @@ class VNet(nn.Module):
|
|||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_0l = nn.Sequential(
|
self.conv_l0 = ResBlock(in_channels, 64, seq='CAC')
|
||||||
ConvBlock(in_channels, 64, seq='CA'),
|
self.down_l0 = ConvBlock(64, seq='BADBA')
|
||||||
ResBlock(64, seq='CBACBACBA'),
|
self.conv_l1 = ResBlock(64, seq='CBAC')
|
||||||
)
|
self.down_l1 = ConvBlock(64, seq='BADBA')
|
||||||
self.down_0l = ConvBlock(64, 128, seq='DBA')
|
|
||||||
self.conv_1l = nn.Sequential(
|
|
||||||
ResBlock(128, seq='CBACBA'),
|
|
||||||
ResBlock(128, seq='CBACBA'),
|
|
||||||
)
|
|
||||||
self.down_1l = ConvBlock(128, 256, seq='DBA')
|
|
||||||
|
|
||||||
self.conv_2c = nn.Sequential(
|
self.conv_c = ResBlock(64, seq='CBAC')
|
||||||
ResBlock(256, seq='CBACBA'),
|
|
||||||
ResBlock(256, seq='CBACBA'),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.up_1r = ConvBlock(256, 128, seq='UBA')
|
self.up_r1 = ConvBlock(64, seq='BAUBA')
|
||||||
self.conv_1r = nn.Sequential(
|
self.conv_r1 = ResBlock(128, 64, seq='CBAC')
|
||||||
ResBlock(256, seq='CBACBA'),
|
self.up_r0 = ConvBlock(64, seq='BAUBA')
|
||||||
ResBlock(256, seq='CBACBA'),
|
self.conv_r0 = ResBlock(128, out_channels, seq='CAC')
|
||||||
)
|
|
||||||
self.up_0r = ConvBlock(256, 64, seq='UBA')
|
|
||||||
self.conv_0r = nn.Sequential(
|
|
||||||
ResBlock(128, seq='CBACBACA'),
|
|
||||||
ConvBlock(128, out_channels, seq='C')
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y0 = self.conv_0l(x)
|
y0 = self.conv_l0(x)
|
||||||
x = self.down_0l(y0)
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
y1 = self.conv_1l(x)
|
y1 = self.conv_l1(x)
|
||||||
x = self.down_1l(y1)
|
x = self.down_l1(y1)
|
||||||
|
|
||||||
x = self.conv_2c(x)
|
x = self.conv_c(x)
|
||||||
|
|
||||||
x = self.up_1r(x)
|
x = self.up_r1(x)
|
||||||
y1 = narrow_like(y1, x)
|
y1 = narrow_like(y1, x)
|
||||||
x = torch.cat([y1, x], dim=1)
|
x = torch.cat([y1, x], dim=1)
|
||||||
del y1
|
del y1
|
||||||
x = self.conv_1r(x)
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
x = self.up_0r(x)
|
x = self.up_r0(x)
|
||||||
y0 = narrow_like(y0, x)
|
y0 = narrow_like(y0, x)
|
||||||
x = torch.cat([y0, x], dim=1)
|
x = torch.cat([y0, x], dim=1)
|
||||||
del y0
|
del y0
|
||||||
x = self.conv_0r(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user