Fix and improve SRSGAN
This commit is contained in:
parent
4be71a32d1
commit
89e8651c26
@ -8,7 +8,7 @@ from .resample import Resampler
|
||||
|
||||
class G(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=True):
|
||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
|
||||
super().__init__()
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
@ -34,9 +34,10 @@ class G(nn.Module):
|
||||
SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
x = self.block0(x)
|
||||
|
||||
y = None
|
||||
#y = None
|
||||
for block in self.blocks:
|
||||
x, y = block(x, y)
|
||||
|
||||
@ -86,8 +87,7 @@ class SkipBlock(nn.Module):
|
||||
)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
AddNoise(cat_noise, chan=next_chan),
|
||||
nn.Conv3d(next_chan + int(cat_noise), out_chan, 1),
|
||||
nn.Conv3d(next_chan, out_chan, 1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
)
|
||||
|
||||
@ -132,7 +132,7 @@ class AddNoise(nn.Module):
|
||||
|
||||
x = x + noise
|
||||
|
||||
return x + noise
|
||||
return x
|
||||
|
||||
|
||||
class D(nn.Module):
|
||||
|
Loading…
Reference in New Issue
Block a user