Replace narrow_like by narrow_by in UNet/VNet

This makes it traceable / scriptable.
Note that the narrow_like in ResBlock used by VNet is not changed yet
This commit is contained in:
Yin Li 2020-09-11 00:23:18 -04:00
parent c4ab7e065b
commit 39ad59436e
2 changed files with 8 additions and 8 deletions

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock from .conv import ConvBlock
from .narrow import narrow_like from .narrow import narrow_by
class UNet(nn.Module): class UNet(nn.Module):
@ -46,19 +46,19 @@ class UNet(nn.Module):
x = self.conv_c(x) x = self.conv_c(x)
x = self.up_r1(x) x = self.up_r1(x)
y1 = narrow_like(y1, x) y1 = narrow_by(y1, 4)
x = torch.cat([y1, x], dim=1) x = torch.cat([y1, x], dim=1)
del y1 del y1
x = self.conv_r1(x) x = self.conv_r1(x)
x = self.up_r0(x) x = self.up_r0(x)
y0 = narrow_like(y0, x) y0 = narrow_by(y0, 16)
x = torch.cat([y0, x], dim=1) x = torch.cat([y0, x], dim=1)
del y0 del y0
x = self.conv_r0(x) x = self.conv_r0(x)
if self.bypass: if self.bypass:
x0 = narrow_like(x0, x) x0 = narrow_by(x0, 20)
x += x0 x += x0
return x return x

View File

@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from .conv import ConvBlock, ResBlock from .conv import ConvBlock, ResBlock
from .narrow import narrow_like from .narrow import narrow_by
class VNet(nn.Module): class VNet(nn.Module):
@ -50,19 +50,19 @@ class VNet(nn.Module):
x = self.conv_c(x) x = self.conv_c(x)
x = self.up_r1(x) x = self.up_r1(x)
y1 = narrow_like(y1, x) y1 = narrow_by(y1, 4)
x = torch.cat([y1, x], dim=1) x = torch.cat([y1, x], dim=1)
del y1 del y1
x = self.conv_r1(x) x = self.conv_r1(x)
x = self.up_r0(x) x = self.up_r0(x)
y0 = narrow_like(y0, x) y0 = narrow_by(y0, 16)
x = torch.cat([y0, x], dim=1) x = torch.cat([y0, x], dim=1)
del y0 del y0
x = self.conv_r0(x) x = self.conv_r0(x)
if self.bypass: if self.bypass:
x0 = narrow_like(x0, x) x0 = narrow_by(x0, 20)
x += x0 x += x0
return x return x