Change VNet based on experiment on displacement

This commit is contained in:
Yin Li 2020-08-17 19:28:10 -07:00
parent ebd962e333
commit 01cc1b6964
3 changed files with 59 additions and 74 deletions

View File

@ -1,5 +1,5 @@
from .unet import UNet from .unet import UNet
from .vnet import VNet, VNetFat from .vnet import VNet
from .patchgan import PatchGAN, PatchGAN42 from .patchgan import PatchGAN, PatchGAN42
from .narrow import narrow_by, narrow_cast, narrow_like from .narrow import narrow_by, narrow_cast, narrow_like
@ -7,8 +7,6 @@ from .resample import resample, Resampler
from .lag2eul import Lag2Eul from .lag2eul import Lag2Eul
from .lag2eul import Lag2Eul
from .dice import DiceLoss, dice_loss from .dice import DiceLoss, dice_loss
from .adversary import adv_model_wrapper, adv_criterion_wrapper from .adversary import adv_model_wrapper, adv_criterion_wrapper

View File

@ -6,22 +6,37 @@ from .narrow import narrow_like
class UNet(nn.Module): class UNet(nn.Module):
def __init__(self, in_chan, out_chan, **kwargs): def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
"""U-Net like network
Note:
Global bypass connection adding the input to the output (similar to
COLA for displacement input and output) from Alvaro Sanchez Gonzalez.
Enabled by default when in_chan equals out_chan
Global bypass, under additive symmetry, effectively obviates --aug-add
"""
super().__init__() super().__init__()
self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC') self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA')
self.down_l0 = ConvBlock(64, seq='BADBA') self.down_l0 = ConvBlock(64, seq='DBA')
self.conv_l1 = ConvBlock(64, seq='CBAC') self.conv_l1 = ConvBlock(64, seq='CBACBA')
self.down_l1 = ConvBlock(64, seq='BADBA') self.down_l1 = ConvBlock(64, seq='DBA')
self.conv_c = ConvBlock(64, seq='CBAC') self.conv_c = ConvBlock(64, seq='CBACBA')
self.up_r1 = ConvBlock(64, seq='BAUBA') self.up_r1 = ConvBlock(64, seq='UBA')
self.conv_r1 = ConvBlock(128, 64, seq='CBAC') self.conv_r1 = ConvBlock(128, 64, seq='CBACBA')
self.up_r0 = ConvBlock(64, seq='BAUBA') self.up_r0 = ConvBlock(64, seq='UBA')
self.conv_r0 = ConvBlock(128, out_chan, seq='CAC') self.conv_r0 = ConvBlock(128, out_chan, seq='CAC')
self.bypass = in_chan == out_chan
def forward(self, x): def forward(self, x):
if self.bypass:
x0 = x
y0 = self.conv_l0(x) y0 = self.conv_l0(x)
x = self.down_l0(y0) x = self.down_l0(y0)
@ -42,4 +57,8 @@ class UNet(nn.Module):
del y0 del y0
x = self.conv_r0(x) x = self.conv_r0(x)
if self.bypass:
x0 = narrow_like(x0, x)
x += x0
return x return x

View File

@ -6,77 +6,41 @@ from .narrow import narrow_like
class VNet(nn.Module): class VNet(nn.Module):
def __init__(self, in_chan, out_chan, **kwargs): def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
"""V-Net like network
Note:
Global bypass connection adding the input to the output (similar to
COLA for displacement input and output) from Alvaro Sanchez Gonzalez.
Enabled by default when in_chan equals out_chan
Global bypass, under additive symmetry, effectively obviates --aug-add
Non-identity skip connection in residual blocks
"""
super().__init__() super().__init__()
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC') # activate non-identity skip connection in residual block
self.down_l0 = ConvBlock(64, seq='BADBA') # by explicitly setting out_chan
self.conv_l1 = ResBlock(64, seq='CBAC') self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
self.down_l1 = ConvBlock(64, seq='BADBA')
self.conv_c = ResBlock(64, seq='CBAC')
self.up_r1 = ConvBlock(64, seq='BAUBA')
self.conv_r1 = ResBlock(128, 64, seq='CBAC')
self.up_r0 = ConvBlock(64, seq='BAUBA')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
def forward(self, x):
y0 = self.conv_l0(x)
x = self.down_l0(y0)
y1 = self.conv_l1(x)
x = self.down_l1(y1)
x = self.conv_c(x)
x = self.up_r1(x)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x)
x = self.up_r0(x)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r0(x)
return x
class VNetFat(nn.Module):
def __init__(self, in_chan, out_chan, **kwargs):
super().__init__()
self.conv_l0 = nn.Sequential(
ResBlock(in_chan, 64, seq='CACBA'),
ResBlock(64, seq='CBACBA'),
)
self.down_l0 = ConvBlock(64, seq='DBA') self.down_l0 = ConvBlock(64, seq='DBA')
self.conv_l1 = nn.Sequential( self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
ResBlock(64, seq='CBACBA'),
ResBlock(64, seq='CBACBA'),
) # FIXME: test CBACBA+DBA vs CBAC+BADBA
self.down_l1 = ConvBlock(64, seq='DBA') self.down_l1 = ConvBlock(64, seq='DBA')
self.conv_c = nn.Sequential( self.conv_c = ResBlock(64, 64, seq='CBACBA')
ResBlock(64, seq='CBACBA'),
ResBlock(64, seq='CBACBA'),
)
self.up_r1 = ConvBlock(64, seq='UBA') self.up_r1 = ConvBlock(64, seq='UBA')
self.conv_r1 = nn.Sequential( self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
ResBlock(128, seq='CBACBA'), self.up_r0 = ConvBlock(64, seq='UBA')
ResBlock(128, seq='CBACBA'), self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
)
self.up_r0 = ConvBlock(128, 64, seq='UBA') self.bypass = in_chan == out_chan
self.conv_r0 = nn.Sequential(
ResBlock(128, seq='CBACBA'),
ResBlock(128, out_chan, seq='CAC'),
)
def forward(self, x): def forward(self, x):
if self.bypass:
x0 = x
y0 = self.conv_l0(x) y0 = self.conv_l0(x)
x = self.down_l0(y0) x = self.down_l0(y0)
@ -97,4 +61,8 @@ class VNetFat(nn.Module):
del y0 del y0
x = self.conv_r0(x) x = self.conv_r0(x)
if self.bypass:
x0 = narrow_like(x0, x)
x += x0
return x return x