Add pyramid network

This commit is contained in:
Yin Li 2020-02-05 20:20:17 -05:00
parent a5c48e71b0
commit 679f9f2545
2 changed files with 46 additions and 0 deletions

View File

@ -1,5 +1,6 @@
from .unet import UNet
from .vnet import VNet, VNetFat
from .pyramid import PyramidNet
from .patchgan import PatchGAN
from .conv import narrow_like

45
map2map/models/pyramid.py Normal file
View File

@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
class PyramidNet(nn.Module):
def __init__(self, in_chan, out_chan):
super().__init__()
self.down = nn.AvgPool3d(2, stride=2)
self.up = nn.Upsample(scale_factor=2, mode='nearest')
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
self.conv_l1 = ResBlock(64, seq='CBAC')
self.conv_c = ResBlock(64, seq='CBAC')
self.conv_r1 = ResBlock(128, 64, seq='CBAC')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
def forward(self, x):
y0 = self.conv_l0(x)
x = self.down(y0)
y0 = y0 - self.up(x)
y1 = self.conv_l1(x)
x = self.down(y1)
y1 = y1 - self.up(x)
x = self.conv_c(x)
x = self.up(x)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x)
x = self.up(x)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r0(x)
return x