Add srsgan model and scale_factor to model arguments

This commit is contained in:
Yin Li 2020-08-04 16:46:43 -04:00
parent 6dfde5ee7f
commit 265587922d
8 changed files with 229 additions and 10 deletions

View file

@ -5,8 +5,8 @@ import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
return a[tuple(ind)]
ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
return a[ind]
def narrow_cast(*tensors):

View file

@ -4,7 +4,7 @@ from .conv import ConvBlock
class PatchGAN(nn.Module):
def __init__(self, in_chan, out_chan=1):
def __init__(self, in_chan, out_chan=1, **kwargs):
super().__init__()
self.convs = nn.Sequential(
@ -21,7 +21,7 @@ class PatchGAN(nn.Module):
class PatchGAN42(nn.Module):
"""PatchGAN similar to the one in pix2pix
"""
def __init__(self, in_chan, out_chan=1):
def __init__(self, in_chan, out_chan=1, **kwargs):
super().__init__()
self.convs = nn.Sequential(

View file

@ -5,7 +5,7 @@ from .conv import ConvBlock, ResBlock, narrow_like
class PyramidNet(nn.Module):
def __init__(self, in_chan, out_chan):
def __init__(self, in_chan, out_chan, **kwargs):
super().__init__()
self.down = nn.AvgPool3d(2, stride=2)

218
map2map/models/srsgan.py Normal file
View file

@ -0,0 +1,218 @@
from math import log2
import torch
import torch.nn as nn
from .narrow import narrow_by, narrow_like
from .resample import Resampler
class G(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512, cat_noise=True):
super().__init__()
self.scale_factor = scale_factor
num_blocks = round(log2(self.scale_factor))
assert chan_min <= chan_max
def chan(b):
c = chan_base >> b
c = max(c, chan_min)
c = min(c, chan_max)
return c
self.block0 = nn.Sequential(
nn.Conv3d(in_chan, chan(0), 1),
nn.LeakyReLU(0.2, True),
)
self.blocks = nn.ModuleList()
for b in range(num_blocks):
prev_chan, next_chan = chan(b), chan(b+1)
self.blocks.append(
SkipBlock(prev_chan, next_chan, out_chan, cat_noise))
def forward(self, x):
x = self.block0(x)
y = None
for block in self.blocks:
x, y = block(x, y)
return y
class SkipBlock(nn.Module):
"""The "I" block of the StyleGAN2 generator.
x_p y_p
| |
convolution linear upsample
| |
>--- projection ------>+
| |
v v
x_n y_n
See Fig. 7 (b) upper in https://arxiv.org/abs/1912.04958
Upsampling are all linear, not transposed convolution.
Parameters
----------
prev_chan : number of channels of x_p
next_chan : number of channels of x_n
out_chan : number of channels of y_p and y_n
cat_noise: concatenate noise if True, otherwise add noise
Notes
-----
next_size = 2 * prev_size - 6
"""
def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
super().__init__()
self.upsample = Resampler(3, 2)
self.conv = nn.Sequential(
AddNoise(cat_noise, chan=prev_chan),
self.upsample,
nn.Conv3d(prev_chan + int(cat_noise), next_chan, 3),
nn.LeakyReLU(0.2, True),
AddNoise(cat_noise, chan=next_chan),
nn.Conv3d(next_chan + int(cat_noise), next_chan, 3),
nn.LeakyReLU(0.2, True),
)
self.proj = nn.Sequential(
AddNoise(cat_noise, chan=next_chan),
nn.Conv3d(next_chan + int(cat_noise), out_chan, 1),
nn.LeakyReLU(0.2, True),
)
def forward(self, x, y):
x = self.conv(x) # narrow by 3
if y is None:
y = self.proj(x)
else:
y = self.upsample(y) # narrow by 1
y = narrow_by(y, 2)
y = y + self.proj(x)
return x, y
class AddNoise(nn.Module):
"""Add or concatenate noise.
Add noise if `cat=False`.
The number of channels `chan` should be 1 (StyleGAN2)
or that of the input (StyleGAN).
"""
def __init__(self, cat, chan=1):
super().__init__()
self.cat = cat
if not self.cat:
self.std = nn.Parameter(torch.zeros([chan]))
def forward(self, x):
noise = torch.randn_like(x[:, :1])
if self.cat:
x = torch.cat([x, noise], dim=1)
else:
std_shape = (-1,) + (1,) * (x.dim() - 2)
noise = self.std.view(std_shape) * noise
x = x + noise
return x + noise
class D(nn.Module):
def __init__(self, in_chan, out_chan, scale_factor=16,
chan_base=512, chan_min=64, chan_max=512):
super().__init__()
self.scale_factor = scale_factor
num_blocks = round(log2(self.scale_factor))
assert chan_min <= chan_max
def chan(b):
if b >= 0:
c = chan_base >> b
else:
c = chan_base << -b
c = max(c, chan_min)
c = min(c, chan_max)
return c
self.block0 = nn.Sequential(
nn.Conv3d(in_chan, chan(num_blocks), 1),
nn.LeakyReLU(0.2, True),
)
self.blocks = nn.ModuleList()
for b in reversed(range(num_blocks)):
prev_chan, next_chan = chan(b+1), chan(b)
self.blocks.append(ResBlock(prev_chan, next_chan))
self.block9 = nn.Sequential(
nn.Conv3d(chan(0), chan(-1), 1),
nn.LeakyReLU(0.2, True),
nn.Conv3d(chan(-1), 1, 1),
)
def forward(self, x):
x = self.block0(x)
for block in self.blocks:
x = block(x)
x = self.block9(x)
return x
class ResBlock(nn.Module):
"""The residual block of the StyleGAN2 discriminator.
Downsampling are all linear, not strided convolution.
Notes
-----
next_size = (prev_size - 4) // 2
"""
def __init__(self, prev_chan, next_chan):
super().__init__()
self.conv = nn.Sequential(
nn.Conv3d(prev_chan, prev_chan, 3),
nn.LeakyReLU(0.2, True),
nn.Conv3d(prev_chan, next_chan, 3),
nn.LeakyReLU(0.2, True),
)
self.skip = nn.Conv3d(prev_chan, next_chan, 1)
self.downsample = Resampler(3, 0.5)
def forward(self, x):
y = self.conv(x)
x = self.skip(x)
x = narrow_by(x, 2)
x = x + y
x = self.downsample(x)
return x

View file

@ -6,7 +6,7 @@ from .narrow import narrow_like
class UNet(nn.Module):
def __init__(self, in_chan, out_chan):
def __init__(self, in_chan, out_chan, **kwargs):
super().__init__()
self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC')

View file

@ -6,7 +6,7 @@ from .narrow import narrow_like
class VNet(nn.Module):
def __init__(self, in_chan, out_chan):
def __init__(self, in_chan, out_chan, **kwargs):
super().__init__()
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
@ -46,7 +46,7 @@ class VNet(nn.Module):
class VNetFat(nn.Module):
def __init__(self, in_chan, out_chan):
def __init__(self, in_chan, out_chan, **kwargs):
super().__init__()
self.conv_l0 = nn.Sequential(

View file

@ -41,7 +41,7 @@ def test(args):
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(in_chan), sum(out_chan))
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
criterion = criterion()

View file

@ -120,7 +120,8 @@ def gpu_worker(local_rank, node, args):
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
model = import_attr(args.model, models.__name__, args.callback_at)
model = model(sum(args.in_chan), sum(args.out_chan))
model = model(sum(args.in_chan), sum(args.out_chan),
scale_factor=args.scale_factor)
model.to(device)
model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group())