Duplicate vnet.py for future modification
To incorporate styled convolution from style.py
This commit is contained in:
parent
a697845933
commit
f5bd657625
60
map2map/models/styled_vnet.py
Normal file
60
map2map/models/styled_vnet.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .conv import ConvBlock, ResBlock
|
||||||
|
from .narrow import narrow_by
|
||||||
|
|
||||||
|
|
||||||
|
class StyledVNet(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan, bypass=None, **kwargs):
|
||||||
|
"""V-Net like network with styles
|
||||||
|
|
||||||
|
See `vnet.VNet`.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# activate non-identity skip connection in residual block
|
||||||
|
# by explicitly setting out_chan
|
||||||
|
self.conv_l0 = ResBlock(in_chan, 64, seq='CACBA')
|
||||||
|
self.down_l0 = ConvBlock(64, seq='DBA')
|
||||||
|
self.conv_l1 = ResBlock(64, 64, seq='CBACBA')
|
||||||
|
self.down_l1 = ConvBlock(64, seq='DBA')
|
||||||
|
|
||||||
|
self.conv_c = ResBlock(64, 64, seq='CBACBA')
|
||||||
|
|
||||||
|
self.up_r1 = ConvBlock(64, seq='UBA')
|
||||||
|
self.conv_r1 = ResBlock(128, 64, seq='CBACBA')
|
||||||
|
self.up_r0 = ConvBlock(64, seq='UBA')
|
||||||
|
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
|
||||||
|
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.bypass:
|
||||||
|
x0 = x
|
||||||
|
|
||||||
|
y0 = self.conv_l0(x)
|
||||||
|
x = self.down_l0(y0)
|
||||||
|
|
||||||
|
y1 = self.conv_l1(x)
|
||||||
|
x = self.down_l1(y1)
|
||||||
|
|
||||||
|
x = self.conv_c(x)
|
||||||
|
|
||||||
|
x = self.up_r1(x)
|
||||||
|
y1 = narrow_by(y1, 4)
|
||||||
|
x = torch.cat([y1, x], dim=1)
|
||||||
|
del y1
|
||||||
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
|
x = self.up_r0(x)
|
||||||
|
y0 = narrow_by(y0, 16)
|
||||||
|
x = torch.cat([y0, x], dim=1)
|
||||||
|
del y0
|
||||||
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
|
if self.bypass:
|
||||||
|
x0 = narrow_by(x0, 20)
|
||||||
|
x += x0
|
||||||
|
|
||||||
|
return x
|
Loading…
Reference in New Issue
Block a user