Fix and improve SRSGAN

This commit is contained in:
Yin Li 2021-03-18 15:23:33 -04:00
parent 4be71a32d1
commit 89e8651c26

View File

@ -8,7 +8,7 @@ from .resample import Resampler
class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=True):
chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
super().__init__()
self.scale_factor = scale_factor
@ -34,9 +34,10 @@ class G(nn.Module):
SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
def forward(self, x):
y = x
x = self.block0(x)
y = None
#y = None
for block in self.blocks:
x, y = block(x, y)
@ -86,8 +87,7 @@ class SkipBlock(nn.Module):
)
self.proj = nn.Sequential(
AddNoise(cat_noise, chan=next_chan),
nn.Conv3d(next_chan + int(cat_noise), out_chan, 1),
nn.Conv3d(next_chan, out_chan, 1),
nn.LeakyReLU(0.2, True),
)
@ -132,7 +132,7 @@ class AddNoise(nn.Module):
x = x + noise
return x + noise
return x
class D(nn.Module):