From 01cc1b6964636516e71ebb15e6eaf627c2beb5dc Mon Sep 17 00:00:00 2001 From: Yin Li Date: Mon, 17 Aug 2020 19:28:10 -0700 Subject: [PATCH] Change VNet based on experiment on displacement --- map2map/models/__init__.py | 4 +- map2map/models/unet.py | 37 +++++++++++---- map2map/models/vnet.py | 92 +++++++++++++------------------------- 3 files changed, 59 insertions(+), 74 deletions(-) diff --git a/map2map/models/__init__.py b/map2map/models/__init__.py index 4b00f93..8017b7a 100644 --- a/map2map/models/__init__.py +++ b/map2map/models/__init__.py @@ -1,5 +1,5 @@ from .unet import UNet -from .vnet import VNet, VNetFat +from .vnet import VNet from .patchgan import PatchGAN, PatchGAN42 from .narrow import narrow_by, narrow_cast, narrow_like @@ -7,8 +7,6 @@ from .resample import resample, Resampler from .lag2eul import Lag2Eul -from .lag2eul import Lag2Eul - from .dice import DiceLoss, dice_loss from .adversary import adv_model_wrapper, adv_criterion_wrapper diff --git a/map2map/models/unet.py b/map2map/models/unet.py index b130e42..106452f 100644 --- a/map2map/models/unet.py +++ b/map2map/models/unet.py @@ -6,22 +6,37 @@ from .narrow import narrow_like class UNet(nn.Module): - def __init__(self, in_chan, out_chan, **kwargs): + def __init__(self, in_chan, out_chan, bypass=None, **kwargs): + """U-Net like network + + Note: + + Global bypass connection adding the input to the output (similar to + COLA for displacement input and output) from Alvaro Sanchez Gonzalez. + Enabled by default when in_chan equals out_chan + + Global bypass, under additive symmetry, effectively obviates --aug-add + """ super().__init__() - self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC') - self.down_l0 = ConvBlock(64, seq='BADBA') - self.conv_l1 = ConvBlock(64, seq='CBAC') - self.down_l1 = ConvBlock(64, seq='BADBA') + self.conv_l0 = ConvBlock(in_chan, 64, seq='CACBA') + self.down_l0 = ConvBlock(64, seq='DBA') + self.conv_l1 = ConvBlock(64, seq='CBACBA') + self.down_l1 = ConvBlock(64, seq='DBA') - self.conv_c = ConvBlock(64, seq='CBAC') + self.conv_c = ConvBlock(64, seq='CBACBA') - self.up_r1 = ConvBlock(64, seq='BAUBA') - self.conv_r1 = ConvBlock(128, 64, seq='CBAC') - self.up_r0 = ConvBlock(64, seq='BAUBA') + self.up_r1 = ConvBlock(64, seq='UBA') + self.conv_r1 = ConvBlock(128, 64, seq='CBACBA') + self.up_r0 = ConvBlock(64, seq='UBA') self.conv_r0 = ConvBlock(128, out_chan, seq='CAC') + self.bypass = in_chan == out_chan + def forward(self, x): + if self.bypass: + x0 = x + y0 = self.conv_l0(x) x = self.down_l0(y0) @@ -42,4 +57,8 @@ class UNet(nn.Module): del y0 x = self.conv_r0(x) + if self.bypass: + x0 = narrow_like(x0, x) + x += x0 + return x diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index eb08f1b..e2da5dd 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -6,77 +6,41 @@ from .narrow import narrow_like class VNet(nn.Module): - def __init__(self, in_chan, out_chan, **kwargs): + def __init__(self, in_chan, out_chan, bypass=None, **kwargs): + """V-Net like network + + Note: + + Global bypass connection adding the input to the output (similar to + COLA for displacement input and output) from Alvaro Sanchez Gonzalez. + Enabled by default when in_chan equals out_chan + + Global bypass, under additive symmetry, effectively obviates --aug-add + + Non-identity skip connection in residual blocks + """ super().__init__() - self.conv_l0 = ResBlock(in_chan, 64, seq='CAC') - self.down_l0 = ConvBlock(64, seq='BADBA') - self.conv_l1 = ResBlock(64, seq='CBAC') - self.down_l1 = ConvBlock(64, seq='BADBA') - - self.conv_c = ResBlock(64, seq='CBAC') - - self.up_r1 = ConvBlock(64, seq='BAUBA') - self.conv_r1 = ResBlock(128, 64, seq='CBAC') - self.up_r0 = ConvBlock(64, seq='BAUBA') - self.conv_r0 = ResBlock(128, out_chan, seq='CAC') - - def forward(self, x): - y0 = self.conv_l0(x) - x = self.down_l0(y0) - - y1 = self.conv_l1(x) - x = self.down_l1(y1) - - x = self.conv_c(x) - - x = self.up_r1(x) - y1 = narrow_like(y1, x) - x = torch.cat([y1, x], dim=1) - del y1 - x = self.conv_r1(x) - - x = self.up_r0(x) - y0 = narrow_like(y0, x) - x = torch.cat([y0, x], dim=1) - del y0 - x = self.conv_r0(x) - - return x - - -class VNetFat(nn.Module): - def __init__(self, in_chan, out_chan, **kwargs): - super().__init__() - - self.conv_l0 = nn.Sequential( - ResBlock(in_chan, 64, seq='CACBA'), - ResBlock(64, seq='CBACBA'), - ) + # activate non-identity skip connection in residual block + # by explicitly setting out_chan + self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA') self.down_l0 = ConvBlock(64, seq='DBA') - self.conv_l1 = nn.Sequential( - ResBlock(64, seq='CBACBA'), - ResBlock(64, seq='CBACBA'), - ) # FIXME: test CBACBA+DBA vs CBAC+BADBA + self.conv_l1 = ResBlock(64, 64, seq='CBACBA') self.down_l1 = ConvBlock(64, seq='DBA') - self.conv_c = nn.Sequential( - ResBlock(64, seq='CBACBA'), - ResBlock(64, seq='CBACBA'), - ) + self.conv_c = ResBlock(64, 64, seq='CBACBA') self.up_r1 = ConvBlock(64, seq='UBA') - self.conv_r1 = nn.Sequential( - ResBlock(128, seq='CBACBA'), - ResBlock(128, seq='CBACBA'), - ) - self.up_r0 = ConvBlock(128, 64, seq='UBA') - self.conv_r0 = nn.Sequential( - ResBlock(128, seq='CBACBA'), - ResBlock(128, out_chan, seq='CAC'), - ) + self.conv_r1 = ResBlock(128, 64, seq='CBACBA') + self.up_r0 = ConvBlock(64, seq='UBA') + self.conv_r0 = ResBlock(128, out_chan, seq='CAC') + + self.bypass = in_chan == out_chan def forward(self, x): + if self.bypass: + x0 = x + y0 = self.conv_l0(x) x = self.down_l0(y0) @@ -97,4 +61,8 @@ class VNetFat(nn.Module): del y0 x = self.conv_r0(x) + if self.bypass: + x0 = narrow_like(x0, x) + x += x0 + return x