Add pyramid network

This commit is contained in:
Yin Li 2020-02-05 20:20:17 -05:00
parent a5c48e71b0
commit 679f9f2545
2 changed files with 46 additions and 0 deletions

View File

@ -1,5 +1,6 @@
from .unet import UNet from .unet import UNet
from .vnet import VNet, VNetFat from .vnet import VNet, VNetFat
from .pyramid import PyramidNet
from .patchgan import PatchGAN from .patchgan import PatchGAN
from .conv import narrow_like from .conv import narrow_like

45
map2map/models/pyramid.py Normal file
View File

@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from .conv import ConvBlock, ResBlock, narrow_like
class PyramidNet(nn.Module):
def __init__(self, in_chan, out_chan):
super().__init__()
self.down = nn.AvgPool3d(2, stride=2)
self.up = nn.Upsample(scale_factor=2, mode='nearest')
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
self.conv_l1 = ResBlock(64, seq='CBAC')
self.conv_c = ResBlock(64, seq='CBAC')
self.conv_r1 = ResBlock(128, 64, seq='CBAC')
self.conv_r0 = ResBlock(128, out_chan, seq='CAC')
def forward(self, x):
y0 = self.conv_l0(x)
x = self.down(y0)
y0 = y0 - self.up(x)
y1 = self.conv_l1(x)
x = self.down(y1)
y1 = y1 - self.up(x)
x = self.conv_c(x)
x = self.up(x)
y1 = narrow_like(y1, x)
x = torch.cat([y1, x], dim=1)
del y1
x = self.conv_r1(x)
x = self.up(x)
y0 = narrow_like(y0, x)
x = torch.cat([y0, x], dim=1)
del y0
x = self.conv_r0(x)
return x