Add srsgan model and scale_factor to model arguments
This commit is contained in:
parent
6dfde5ee7f
commit
265587922d
8 changed files with 229 additions and 10 deletions
|
@ -5,8 +5,8 @@ import torch.nn as nn
|
|||
def narrow_by(a, c):
    """Narrow a by size c symmetrically on all edges.

    The first two dimensions of `a` are kept whole; each remaining
    dimension loses `c` entries from both of its ends.

    Parameters
    ----------
    a : object exposing `dim()` and tuple indexing (e.g. a torch tensor)
    c : number of entries to remove from each edge; with `c == 0`
        nothing is removed

    Returns
    -------
    The result of indexing `a` with the narrowing slices.
    """
    # slice(c, -c) would be empty for c == 0, so leave the end open then
    stop = -c if c else None
    ind = (slice(None),) * 2 + (slice(c, stop),) * (a.dim() - 2)
    return a[ind]
|
||||
|
||||
|
||||
def narrow_cast(*tensors):
|
||||
|
|
|
@ -4,7 +4,7 @@ from .conv import ConvBlock
|
|||
|
||||
|
||||
class PatchGAN(nn.Module):
|
||||
def __init__(self, in_chan, out_chan=1):
|
||||
def __init__(self, in_chan, out_chan=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.convs = nn.Sequential(
|
||||
|
@ -21,7 +21,7 @@ class PatchGAN(nn.Module):
|
|||
class PatchGAN42(nn.Module):
|
||||
"""PatchGAN similar to the one in pix2pix
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan=1):
|
||||
def __init__(self, in_chan, out_chan=1, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.convs = nn.Sequential(
|
||||
|
|
|
@ -5,7 +5,7 @@ from .conv import ConvBlock, ResBlock, narrow_like
|
|||
|
||||
|
||||
class PyramidNet(nn.Module):
|
||||
def __init__(self, in_chan, out_chan):
|
||||
def __init__(self, in_chan, out_chan, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.AvgPool3d(2, stride=2)
|
||||
|
|
218
map2map/models/srsgan.py
Normal file
218
map2map/models/srsgan.py
Normal file
|
@ -0,0 +1,218 @@
|
|||
from math import log2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .narrow import narrow_by, narrow_like
|
||||
from .resample import Resampler
|
||||
|
||||
|
||||
class G(nn.Module):
    """Generator: a 1x1x1 input convolution followed by a chain of SkipBlocks.

    The number of SkipBlocks is `round(log2(scale_factor))`.

    Parameters
    ----------
    in_chan : number of input channels
    out_chan : number of channels of the output accumulated by the SkipBlocks
    scale_factor : sets the number of SkipBlocks, see above
    chan_base : channel count at level 0, halved (bit shift) at each level
    chan_min, chan_max : bounds the per-level channel count is clamped into
    cat_noise : forwarded to every SkipBlock
    """
    def __init__(self, in_chan, out_chan, scale_factor=16,
                 chan_base=512, chan_min=64, chan_max=512, cat_noise=True):
        super().__init__()

        self.scale_factor = scale_factor
        depth = round(log2(self.scale_factor))

        assert chan_min <= chan_max

        def chan(b):
            # halve chan_base per level, then clamp into [chan_min, chan_max]
            return min(max(chan_base >> b, chan_min), chan_max)

        self.block0 = nn.Sequential(
            nn.Conv3d(in_chan, chan(0), 1),
            nn.LeakyReLU(0.2, True),
        )

        self.blocks = nn.ModuleList([
            SkipBlock(chan(level), chan(level + 1), out_chan, cat_noise)
            for level in range(depth)
        ])

    def forward(self, x):
        """Run the input convolution and all SkipBlocks, return the final y.

        Returns None when there are no SkipBlocks.
        """
        x = self.block0(x)

        # the first SkipBlock receives y=None and creates the initial output
        y = None
        for stage in self.blocks:
            x, y = stage(x, y)

        return y
|
||||
|
||||
|
||||
class SkipBlock(nn.Module):
    """The "I" block of the StyleGAN2 generator.

        x_p                     y_p
         |                       |
    convolution           linear upsample
         |                       |
         >--- projection ------>+
         |                       |
         v                       v
        x_n                     y_n

    See Fig. 7 (b) upper in https://arxiv.org/abs/1912.04958
    All upsampling is linear, not transposed convolution.

    Parameters
    ----------
    prev_chan : number of channels of x_p
    next_chan : number of channels of x_n
    out_chan : number of channels of y_p and y_n
    cat_noise: concatenate noise if True, otherwise add noise

    Notes
    -----
    next_size = 2 * prev_size - 6
    """
    def __init__(self, prev_chan, next_chan, out_chan, cat_noise):
        super().__init__()

        # one resampler instance, used both inside `conv` and on y in forward
        self.upsample = Resampler(3, 2)

        # concatenated noise widens each convolution input by one channel
        extra = int(cat_noise)

        self.conv = nn.Sequential(
            AddNoise(cat_noise, chan=prev_chan),
            self.upsample,
            nn.Conv3d(prev_chan + extra, next_chan, 3),
            nn.LeakyReLU(0.2, True),

            AddNoise(cat_noise, chan=next_chan),
            nn.Conv3d(next_chan + extra, next_chan, 3),
            nn.LeakyReLU(0.2, True),
        )

        self.proj = nn.Sequential(
            AddNoise(cat_noise, chan=next_chan),
            nn.Conv3d(next_chan + extra, out_chan, 1),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x, y):
        """Return the pair (x_n, y_n); `y` may be None for the first block."""
        x = self.conv(x)  # narrow by 3

        if y is not None:
            # upsampling narrows by 1; narrow 2 more to match the size of x
            y = narrow_by(self.upsample(y), 2)
            return x, y + self.proj(x)

        # no previous output: the projection alone starts it
        return x, self.proj(x)
|
||||
|
||||
|
||||
class AddNoise(nn.Module):
    """Add or concatenate noise.

    Add noise if `cat=False`.
    The number of channels `chan` should be 1 (StyleGAN2)
    or that of the input (StyleGAN).

    Parameters
    ----------
    cat : concatenate one noise channel if true, otherwise add scaled noise
    chan : length of the learnable noise scale `std`, which is only created
        when `cat` is false and is initialized to zeros
    """
    def __init__(self, cat, chan=1):
        super().__init__()

        self.cat = cat

        if not self.cat:
            self.std = nn.Parameter(torch.zeros([chan]))

    def forward(self, x):
        """Return `x` with noise applied exactly once.

        With `cat` true the result has one more channel (dim 1) than `x`,
        and the original channels are left untouched. Otherwise the result
        has the shape of `x` and equals `x + std * noise`.
        """
        # a single noise channel, standard normal, same dtype/device as x
        noise = torch.randn_like(x[:, :1])

        if self.cat:
            return torch.cat([x, noise], dim=1)

        # reshape std to (chan, 1, ..., 1) so it broadcasts over channels
        std_shape = (-1,) + (1,) * (x.dim() - 2)
        noise = self.std.view(std_shape) * noise

        return x + noise
|
||||
|
||||
|
||||
class D(nn.Module):
    """Discriminator: input convolution, a chain of ResBlocks, 1x1x1 head.

    The number of ResBlocks is `round(log2(scale_factor))`. The head ends
    in a single output channel; `out_chan` is accepted but not used here.

    Parameters
    ----------
    in_chan : number of input channels
    out_chan : unused in this class
    scale_factor : sets the number of ResBlocks, see above
    chan_base : channel count at level 0; shifted right for positive
        levels and left for negative ones
    chan_min, chan_max : bounds the per-level channel count is clamped into
    """
    def __init__(self, in_chan, out_chan, scale_factor=16,
                 chan_base=512, chan_min=64, chan_max=512):
        super().__init__()

        self.scale_factor = scale_factor
        depth = round(log2(self.scale_factor))

        assert chan_min <= chan_max

        def chan(b):
            # negative levels double chan_base instead of halving it
            width = chan_base >> b if b >= 0 else chan_base << -b
            return min(max(width, chan_min), chan_max)

        self.block0 = nn.Sequential(
            nn.Conv3d(in_chan, chan(depth), 1),
            nn.LeakyReLU(0.2, True),
        )

        # levels run from depth - 1 down to 0, widening the channel count
        self.blocks = nn.ModuleList([
            ResBlock(chan(level + 1), chan(level))
            for level in reversed(range(depth))
        ])

        head_chan = chan(-1)
        self.block9 = nn.Sequential(
            nn.Conv3d(chan(0), head_chan, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv3d(head_chan, 1, 1),
        )

    def forward(self, x):
        """Apply the input convolution, every ResBlock, then the head."""
        x = self.block0(x)

        for stage in self.blocks:
            x = stage(x)

        return self.block9(x)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """The residual block of the StyleGAN2 discriminator.

    All downsampling is linear, not strided convolution.

    Parameters
    ----------
    prev_chan : number of channels of the input
    next_chan : number of channels of the output

    Notes
    -----
    next_size = (prev_size - 4) // 2
    """
    def __init__(self, prev_chan, next_chan):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv3d(prev_chan, prev_chan, 3),
            nn.LeakyReLU(0.2, True),

            nn.Conv3d(prev_chan, next_chan, 3),
            nn.LeakyReLU(0.2, True),
        )

        self.skip = nn.Conv3d(prev_chan, next_chan, 1)

        self.downsample = Resampler(3, 0.5)

    def forward(self, x):
        """Sum the convolution and skip branches, then downsample."""
        main = self.conv(x)

        # the 1x1x1 skip keeps the input size, so narrow it by 2 per edge
        # before adding it to the convolution branch
        shortcut = narrow_by(self.skip(x), 2)

        return self.downsample(shortcut + main)
|
|
@ -6,7 +6,7 @@ from .narrow import narrow_like
|
|||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self, in_chan, out_chan):
|
||||
def __init__(self, in_chan, out_chan, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.conv_l0 = ConvBlock(in_chan, 64, seq='CAC')
|
||||
|
|
|
@ -6,7 +6,7 @@ from .narrow import narrow_like
|
|||
|
||||
|
||||
class VNet(nn.Module):
|
||||
def __init__(self, in_chan, out_chan):
|
||||
def __init__(self, in_chan, out_chan, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.conv_l0 = ResBlock(in_chan, 64, seq='CAC')
|
||||
|
@ -46,7 +46,7 @@ class VNet(nn.Module):
|
|||
|
||||
|
||||
class VNetFat(nn.Module):
|
||||
def __init__(self, in_chan, out_chan):
|
||||
def __init__(self, in_chan, out_chan, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.conv_l0 = nn.Sequential(
|
||||
|
|
|
@ -41,7 +41,7 @@ def test(args):
|
|||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models.__name__, args.callback_at)
|
||||
model = model(sum(in_chan), sum(out_chan))
|
||||
model = model(sum(in_chan), sum(out_chan), scale_factor=args.scale_factor)
|
||||
criterion = import_attr(args.criterion, torch.nn.__name__, args.callback_at)
|
||||
criterion = criterion()
|
||||
|
||||
|
|
|
@ -120,7 +120,8 @@ def gpu_worker(local_rank, node, args):
|
|||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||
|
||||
model = import_attr(args.model, models.__name__, args.callback_at)
|
||||
model = model(sum(args.in_chan), sum(args.out_chan))
|
||||
model = model(sum(args.in_chan), sum(args.out_chan),
|
||||
scale_factor=args.scale_factor)
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
|
Loading…
Reference in a new issue