Add PatchGAN42 similar to the PatchGAN in pix2pix

This commit is contained in:
Yin Li 2020-02-07 14:32:43 -05:00
parent 45c1d57e72
commit f831afbccf
3 changed files with 31 additions and 5 deletions

View File

@ -1,7 +1,7 @@
from .unet import UNet from .unet import UNet
from .vnet import VNet, VNetFat from .vnet import VNet, VNetFat
from .pyramid import PyramidNet from .pyramid import PyramidNet
from .patchgan import PatchGAN from .patchgan import PatchGAN, PatchGAN42
from .conv import narrow_like from .conv import narrow_like

View File

@ -10,8 +10,34 @@ class PatchGAN(nn.Module):
self.convs = nn.Sequential( self.convs = nn.Sequential(
ConvBlock(in_chan, 32, seq='CA'), ConvBlock(in_chan, 32, seq='CA'),
ConvBlock(32, 64, seq='CBA'), ConvBlock(32, 64, seq='CBA'),
ConvBlock(64, 128, seq='CBA'), ConvBlock(64, seq='CBA'),
nn.Conv3d(128, out_chan, 1) ConvBlock(64, 32, seq='CBA'),
nn.Conv3d(32, out_chan, 1),
)
def forward(self, x):
return self.convs(x)
class PatchGAN42(nn.Module):
"""PatchGAN similar to the one in pix2pix
"""
def __init__(self, in_chan, out_chan=1):
super().__init__()
self.convs = nn.Sequential(
nn.Conv3d(in_chan, 64, 4, stride=2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(64, 128, 4, stride=2),
nn.BatchNorm3d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(128, 256, 4, stride=2),
nn.BatchNorm3d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(256, out_chan, 1),
) )
def forward(self, x): def forward(self, x):

View File

@ -58,7 +58,7 @@ def gpu_worker(local_rank, args):
noise_chan=args.noise_chan, noise_chan=args.noise_chan,
cache=args.cache, cache=args.cache,
div_data=args.div_data, div_data=args.div_data,
rank=rank, rank=args.rank,
world_size=args.world_size, world_size=args.world_size,
) )
if not args.div_data: if not args.div_data:
@ -88,7 +88,7 @@ def gpu_worker(local_rank, args):
noise_chan=args.noise_chan, noise_chan=args.noise_chan,
cache=args.cache, cache=args.cache,
div_data=args.div_data, div_data=args.div_data,
rank=rank, rank=args.rank,
world_size=args.world_size, world_size=args.world_size,
) )
if not args.div_data: if not args.div_data: