Replace narrow_like by narrow_by in UNet/VNet
This makes it traceable / scriptable. Note that the narrow_like in ResBlock used by VNet is not changed yet
This commit is contained in:
parent
c4ab7e065b
commit
39ad59436e
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .conv import ConvBlock
|
from .conv import ConvBlock
|
||||||
from .narrow import narrow_like
|
from .narrow import narrow_by
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
@ -46,19 +46,19 @@ class UNet(nn.Module):
|
|||||||
x = self.conv_c(x)
|
x = self.conv_c(x)
|
||||||
|
|
||||||
x = self.up_r1(x)
|
x = self.up_r1(x)
|
||||||
y1 = narrow_like(y1, x)
|
y1 = narrow_by(y1, 4)
|
||||||
x = torch.cat([y1, x], dim=1)
|
x = torch.cat([y1, x], dim=1)
|
||||||
del y1
|
del y1
|
||||||
x = self.conv_r1(x)
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
x = self.up_r0(x)
|
x = self.up_r0(x)
|
||||||
y0 = narrow_like(y0, x)
|
y0 = narrow_by(y0, 16)
|
||||||
x = torch.cat([y0, x], dim=1)
|
x = torch.cat([y0, x], dim=1)
|
||||||
del y0
|
del y0
|
||||||
x = self.conv_r0(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
if self.bypass:
|
if self.bypass:
|
||||||
x0 = narrow_like(x0, x)
|
x0 = narrow_by(x0, 20)
|
||||||
x += x0
|
x += x0
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -2,7 +2,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .conv import ConvBlock, ResBlock
|
from .conv import ConvBlock, ResBlock
|
||||||
from .narrow import narrow_like
|
from .narrow import narrow_by
|
||||||
|
|
||||||
|
|
||||||
class VNet(nn.Module):
|
class VNet(nn.Module):
|
||||||
@ -50,19 +50,19 @@ class VNet(nn.Module):
|
|||||||
x = self.conv_c(x)
|
x = self.conv_c(x)
|
||||||
|
|
||||||
x = self.up_r1(x)
|
x = self.up_r1(x)
|
||||||
y1 = narrow_like(y1, x)
|
y1 = narrow_by(y1, 4)
|
||||||
x = torch.cat([y1, x], dim=1)
|
x = torch.cat([y1, x], dim=1)
|
||||||
del y1
|
del y1
|
||||||
x = self.conv_r1(x)
|
x = self.conv_r1(x)
|
||||||
|
|
||||||
x = self.up_r0(x)
|
x = self.up_r0(x)
|
||||||
y0 = narrow_like(y0, x)
|
y0 = narrow_by(y0, 16)
|
||||||
x = torch.cat([y0, x], dim=1)
|
x = torch.cat([y0, x], dim=1)
|
||||||
del y0
|
del y0
|
||||||
x = self.conv_r0(x)
|
x = self.conv_r0(x)
|
||||||
|
|
||||||
if self.bypass:
|
if self.bypass:
|
||||||
x0 = narrow_like(x0, x)
|
x0 = narrow_by(x0, 20)
|
||||||
x += x0
|
x += x0
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user