From 89e8651c2655947d7d0b1beb9ac5f93652b36748 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 18 Mar 2021 15:23:33 -0400 Subject: [PATCH] Fix and improve SRSGAN --- map2map/models/srsgan.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/map2map/models/srsgan.py b/map2map/models/srsgan.py index c4fb096..ff0de9a 100644 --- a/map2map/models/srsgan.py +++ b/map2map/models/srsgan.py @@ -8,7 +8,7 @@ from .resample import Resampler class G(nn.Module): def __init__(self, in_chan, out_chan, scale_factor=16, - chan_base=512, chan_min=64, chan_max=512, cat_noise=True): + chan_base=512, chan_min=64, chan_max=512, cat_noise=False): super().__init__() self.scale_factor = scale_factor @@ -34,9 +34,10 @@ class G(nn.Module): SkipBlock(prev_chan, next_chan, out_chan, cat_noise)) def forward(self, x): + y = x x = self.block0(x) - y = None + #y = None for block in self.blocks: x, y = block(x, y) @@ -86,8 +87,7 @@ class SkipBlock(nn.Module): ) self.proj = nn.Sequential( - AddNoise(cat_noise, chan=next_chan), - nn.Conv3d(next_chan + int(cat_noise), out_chan, 1), + nn.Conv3d(next_chan, out_chan, 1), nn.LeakyReLU(0.2, True), ) @@ -132,7 +132,7 @@ class AddNoise(nn.Module): x = x + noise - return x + noise + return x class D(nn.Module):