diff --git a/map2map/models/unet.py b/map2map/models/unet.py index 106452f..bf78221 100644 --- a/map2map/models/unet.py +++ b/map2map/models/unet.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from .conv import ConvBlock -from .narrow import narrow_like +from .narrow import narrow_by class UNet(nn.Module): @@ -46,19 +46,19 @@ class UNet(nn.Module): x = self.conv_c(x) x = self.up_r1(x) - y1 = narrow_like(y1, x) + y1 = narrow_by(y1, 4) x = torch.cat([y1, x], dim=1) del y1 x = self.conv_r1(x) x = self.up_r0(x) - y0 = narrow_like(y0, x) + y0 = narrow_by(y0, 16) x = torch.cat([y0, x], dim=1) del y0 x = self.conv_r0(x) if self.bypass: - x0 = narrow_like(x0, x) + x0 = narrow_by(x0, 20) x += x0 return x diff --git a/map2map/models/vnet.py b/map2map/models/vnet.py index e2da5dd..f862da8 100644 --- a/map2map/models/vnet.py +++ b/map2map/models/vnet.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from .conv import ConvBlock, ResBlock -from .narrow import narrow_like +from .narrow import narrow_by class VNet(nn.Module): @@ -50,19 +50,19 @@ class VNet(nn.Module): x = self.conv_c(x) x = self.up_r1(x) - y1 = narrow_like(y1, x) + y1 = narrow_by(y1, 4) x = torch.cat([y1, x], dim=1) del y1 x = self.conv_r1(x) x = self.up_r0(x) - y0 = narrow_like(y0, x) + y0 = narrow_by(y0, 16) x = torch.cat([y0, x], dim=1) del y0 x = self.conv_r0(x) if self.bypass: - x0 = narrow_like(x0, x) + x0 = narrow_by(x0, 20) x += x0 return x