Change VNet based on experiment on displacement
This commit is contained in:
parent
ebd962e333
commit
01cc1b6964
@ -1,5 +1,5 @@
|
|||||||
from .unet import UNet
|
from .unet import UNet
|
||||||
from .vnet import VNet, VNetFat
|
from .vnet import VNet
|
||||||
from .patchgan import PatchGAN, PatchGAN42
|
from .patchgan import PatchGAN, PatchGAN42
|
||||||
|
|
||||||
from .narrow import narrow_by, narrow_cast, narrow_like
|
from .narrow import narrow_by, narrow_cast, narrow_like
|
||||||
@ -7,8 +7,6 @@ from .resample import resample, Resampler
|
|||||||
|
|
||||||
from .lag2eul import Lag2Eul
|
from .lag2eul import Lag2Eul
|
||||||
|
|
||||||
from .lag2eul import Lag2Eul
|
|
||||||
|
|
||||||
from .dice import DiceLoss, dice_loss
|
from .dice import DiceLoss, dice_loss
|
||||||
|
|
||||||
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
from .adversary import adv_model_wrapper, adv_criterion_wrapper
|
||||||
|
@ -6,22 +6,37 @@ from .narrow import narrow_like
|
|||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, in_chan, out_chan, **kwargs):
|
def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
|
||||||
|
"""U-Net like network
|
||||||
|
|
||||||
|
Note:
|
||||||
|
|
||||||
|
Global bypass connection adding the input to the output (similar to
|
||||||
|
COLA for displacement input and output) from Alvaro Sanchez Gonzalez.
|
||||||
|
Enabled by default when in_chan equals out_chan
|
||||||
|
|
||||||
|
Global bypass, under additive symmetry, effectively obviates --aug-add
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC')
|
self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA')
|
||||||
self.down_l0 = ConvBlock(64, seq='BADBA')
|
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||||
self.conv_l1 = ConvBlock(64, seq='CBAC')
|
self.conv_l1 = ConvBlock(64, seq='CBACBA')
|
||||||
self.down_l1 = ConvBlock(64, seq='BADBA')
|
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||||
|
|
||||||
self.conv_c = ConvBlock(64, seq='CBAC')
|
self.conv_c = ConvBlock(64, seq='CBACBA')
|
||||||
|
|
||||||
self.up_r1 = ConvBlock(64, seq='BAUBA')
|
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||||
self.conv_r1 = ConvBlock(128, 64, seq='CBAC')
|
self.conv_r1 = ConvBlock(128, 64, seq='CBACBA')
|
||||||
self.up_r0 = ConvBlock(64, seq='BAUBA')
|
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||||
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
|
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
|
||||||
|
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if self.bypass:
|
||||||
|
x0 = x
|
||||||
|
|
||||||
y0 = self.conv_l0(x)
|
y0 = self.conv_l0(x)
|
||||||
x = self.down_l0(y0)
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
@ -42,4 +57,8 @@ class UNet(nn.Module):
|
|||||||
del y0
|
del y0
|
||||||
x = self.conv_r0(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
|
if self.bypass:
|
||||||
|
x0 = narrow_like(x0, x)
|
||||||
|
x += x0
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -6,77 +6,41 @@ from .narrow import narrow_like
|
|||||||
|
|
||||||
|
|
||||||
class VNet(nn.Module):
|
class VNet(nn.Module):
|
||||||
def __init__(self, in_chan, out_chan, **kwargs):
|
def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
|
||||||
|
"""V-Net like network
|
||||||
|
|
||||||
|
Note:
|
||||||
|
|
||||||
|
Global bypass connection adding the input to the output (similar to
|
||||||
|
COLA for displacement input and output) from Alvaro Sanchez Gonzalez.
|
||||||
|
Enabled by default when in_chan equals out_chan
|
||||||
|
|
||||||
|
Global bypass, under additive symmetry, effectively obviates --aug-add
|
||||||
|
|
||||||
|
Non-identity skip connection in residual blocks
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
|
# activate non-identity skip connection in residual block
|
||||||
self.down_l0 = ConvBlock(64, seq='BADBA')
|
# by explicitly setting out_chan
|
||||||
self.conv_l1 = ResBlock(64, seq='CBAC')
|
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
|
||||||
self.down_l1 = ConvBlock(64, seq='BADBA')
|
|
||||||
|
|
||||||
self.conv_c = ResBlock(64, seq='CBAC')
|
|
||||||
|
|
||||||
self.up_r1 = ConvBlock(64, seq='BAUBA')
|
|
||||||
self.conv_r1 = ResBlock(128, 64, seq='CBAC')
|
|
||||||
self.up_r0 = ConvBlock(64, seq='BAUBA')
|
|
||||||
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
y0 = self.conv_l0(x)
|
|
||||||
x = self.down_l0(y0)
|
|
||||||
|
|
||||||
y1 = self.conv_l1(x)
|
|
||||||
x = self.down_l1(y1)
|
|
||||||
|
|
||||||
x = self.conv_c(x)
|
|
||||||
|
|
||||||
x = self.up_r1(x)
|
|
||||||
y1 = narrow_like(y1, x)
|
|
||||||
x = torch.cat([y1, x], dim=1)
|
|
||||||
del y1
|
|
||||||
x = self.conv_r1(x)
|
|
||||||
|
|
||||||
x = self.up_r0(x)
|
|
||||||
y0 = narrow_like(y0, x)
|
|
||||||
x = torch.cat([y0, x], dim=1)
|
|
||||||
del y0
|
|
||||||
x = self.conv_r0(x)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class VNetFat(nn.Module):
|
|
||||||
def __init__(self, in_chan, out_chan, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.conv_l0 = nn.Sequential(
|
|
||||||
ResBlock(in_chan, 64, seq='CACBA'),
|
|
||||||
ResBlock(64, seq='CBACBA'),
|
|
||||||
)
|
|
||||||
self.down_l0 = ConvBlock(64, seq='DBA')
|
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||||
self.conv_l1 = nn.Sequential(
|
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
|
||||||
ResBlock(64, seq='CBACBA'),
|
|
||||||
ResBlock(64, seq='CBACBA'),
|
|
||||||
) # FIXME: test CBACBA+DBA vs CBAC+BADBA
|
|
||||||
self.down_l1 = ConvBlock(64, seq='DBA')
|
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||||
|
|
||||||
self.conv_c = nn.Sequential(
|
self.conv_c = ResBlock(64, 64, seq='CBACBA')
|
||||||
ResBlock(64, seq='CBACBA'),
|
|
||||||
ResBlock(64, seq='CBACBA'),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.up_r1 = ConvBlock(64, seq='UBA')
|
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||||
self.conv_r1 = nn.Sequential(
|
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
|
||||||
ResBlock(128, seq='CBACBA'),
|
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||||
ResBlock(128, seq='CBACBA'),
|
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||||
)
|
|
||||||
self.up_r0 = ConvBlock(128, 64, seq='UBA')
|
self.bypass = in_chan == out_chan
|
||||||
self.conv_r0 = nn.Sequential(
|
|
||||||
ResBlock(128, seq='CBACBA'),
|
|
||||||
ResBlock(128, out_chan, seq='CAC'),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if self.bypass:
|
||||||
|
x0 = x
|
||||||
|
|
||||||
y0 = self.conv_l0(x)
|
y0 = self.conv_l0(x)
|
||||||
x = self.down_l0(y0)
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
@ -97,4 +61,8 @@ class VNetFat(nn.Module):
|
|||||||
del y0
|
del y0
|
||||||
x = self.conv_r0(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
|
if self.bypass:
|
||||||
|
x0 = narrow_like(x0, x)
|
||||||
|
x += x0
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user