Fix and improve SRSGAN
This commit is contained in:
parent
4be71a32d1
commit
89e8651c26
@ -8,7 +8,7 @@ from .resample import Resampler
|
|||||||
|
|
||||||
class G(nn.Module):
|
class G(nn.Module):
|
||||||
def __init__(self, in_chan, out_chan, scale_factor=16,
|
def __init__(self, in_chan, out_chan, scale_factor=16,
|
||||||
chan_base=512, chan_min=64, chan_max=512, cat_noise=True):
|
chan_base=512, chan_min=64, chan_max=512, cat_noise=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -34,9 +34,10 @@ class G(nn.Module):
|
|||||||
SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
|
SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
y = x
|
||||||
x = self.block0(x)
|
x = self.block0(x)
|
||||||
|
|
||||||
y = None
|
#y = None
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x, y = block(x, y)
|
x, y = block(x, y)
|
||||||
|
|
||||||
@ -86,8 +87,7 @@ class SkipBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.proj = nn.Sequential(
|
self.proj = nn.Sequential(
|
||||||
AddNoise(cat_noise, chan=next_chan),
|
nn.Conv3d(next_chan, out_chan, 1),
|
||||||
nn.Conv3d(next_chan + int(cat_noise), out_chan, 1),
|
|
||||||
nn.LeakyReLU(0.2, True),
|
nn.LeakyReLU(0.2, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -132,7 +132,7 @@ class AddNoise(nn.Module):
|
|||||||
|
|
||||||
x = x + noise
|
x = x + noise
|
||||||
|
|
||||||
return x + noise
|
return x
|
||||||
|
|
||||||
|
|
||||||
class D(nn.Module):
|
class D(nn.Module):
|
||||||
|
Loading…
Reference in New Issue
Block a user