Fix and improve SRSGAN

This commit is contained in:
Yin Li 2021-03-18 15:23:33 -04:00
parent 4be71a32d1
commit 89e8651c26

View File

@ -8,7 +8,7 @@ from .resample import Resampler
class G(nn.Module): class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16, def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=True): chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
super().__init__() super().__init__()
self.scale_factor = scale_factor self.scale_factor = scale_factor
@ -34,9 +34,10 @@ class G(nn.Module):
SkipBlock(prev_chan, next_chan, out_chan, cat_noise)) SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
def forward(self, x): def forward(self, x):
y = x
x = self.block0(x) x = self.block0(x)
y = None #y = None
for block in self.blocks: for block in self.blocks:
x, y = block(x, y) x, y = block(x, y)
@ -86,8 +87,7 @@ class SkipBlock(nn.Module):
) )
self.proj = nn.Sequential( self.proj = nn.Sequential(
AddNoise(cat_noise, chan=next_chan), nn.Conv3d(next_chan, out_chan, 1),
nn.Conv3d(next_chan + int(cat_noise), out_chan, 1),
nn.LeakyReLU(0.2, True), nn.LeakyReLU(0.2, True),
) )
@ -132,7 +132,7 @@ class AddNoise(nn.Module):
x = x + noise x = x + noise
return x + noise return x
class D(nn.Module): class D(nn.Module):