diff --git a/map2map/models/style.py b/map2map/models/style.py index 94fc93e..2c1c80c 100644 --- a/map2map/models/style.py +++ b/map2map/models/style.py @@ -158,7 +158,7 @@ class ConvStyled3d(nn.Module): x = x.reshape(1, N * Cin, *DHWin) x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N) _, _, *DHWout = x.shape - x = x.reshape(N, Cout, *DHWout) + x = x.reshape(N, Cout, *DHWout) return x