Replace narrow_like by narrow_by in UNet/VNet
This makes it traceable / scriptable. Note that the narrow_like in ResBlock used by VNet is not changed yet
This commit is contained in:
parent
c4ab7e065b
commit
39ad59436e
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock
|
||||
from .narrow import narrow_like
|
||||
from .narrow import narrow_by
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
@ -46,19 +46,19 @@ class UNet(nn.Module):
|
||||
x = self.conv_c(x)
|
||||
|
||||
x = self.up_r1(x)
|
||||
y1 = narrow_like(y1, x)
|
||||
y1 = narrow_by(y1, 4)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x)
|
||||
|
||||
x = self.up_r0(x)
|
||||
y0 = narrow_like(y0, x)
|
||||
y0 = narrow_by(y0, 16)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r0(x)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_like(x0, x)
|
||||
x0 = narrow_by(x0, 20)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .conv import ConvBlock, ResBlock
|
||||
from .narrow import narrow_like
|
||||
from .narrow import narrow_by
|
||||
|
||||
|
||||
class VNet(nn.Module):
|
||||
@ -50,19 +50,19 @@ class VNet(nn.Module):
|
||||
x = self.conv_c(x)
|
||||
|
||||
x = self.up_r1(x)
|
||||
y1 = narrow_like(y1, x)
|
||||
y1 = narrow_by(y1, 4)
|
||||
x = torch.cat([y1, x], dim=1)
|
||||
del y1
|
||||
x = self.conv_r1(x)
|
||||
|
||||
x = self.up_r0(x)
|
||||
y0 = narrow_like(y0, x)
|
||||
y0 = narrow_by(y0, 16)
|
||||
x = torch.cat([y0, x], dim=1)
|
||||
del y0
|
||||
x = self.conv_r0(x)
|
||||
|
||||
if self.bypass:
|
||||
x0 = narrow_like(x0, x)
|
||||
x0 = narrow_by(x0, 20)
|
||||
x += x0
|
||||
|
||||
return x
|
||||
|
Loading…
Reference in New Issue
Block a user