Add PatchGAN42 similar to the PatchGAN in pix2pix
This commit is contained in:
parent
45c1d57e72
commit
f831afbccf
@ -1,7 +1,7 @@
|
|||||||
from .unet import UNet
|
from .unet import UNet
|
||||||
from .vnet import VNet, VNetFat
|
from .vnet import VNet, VNetFat
|
||||||
from .pyramid import PyramidNet
|
from .pyramid import PyramidNet
|
||||||
from .patchgan import PatchGAN
|
from .patchgan import PatchGAN, PatchGAN42
|
||||||
|
|
||||||
from .conv import narrow_like
|
from .conv import narrow_like
|
||||||
|
|
||||||
|
@ -10,8 +10,34 @@ class PatchGAN(nn.Module):
|
|||||||
self.convs = nn.Sequential(
|
self.convs = nn.Sequential(
|
||||||
ConvBlock(in_chan, 32, seq='CA'),
|
ConvBlock(in_chan, 32, seq='CA'),
|
||||||
ConvBlock(32, 64, seq='CBA'),
|
ConvBlock(32, 64, seq='CBA'),
|
||||||
ConvBlock(64, 128, seq='CBA'),
|
ConvBlock(64, seq='CBA'),
|
||||||
nn.Conv3d(128, out_chan, 1)
|
ConvBlock(64, 32, seq='CBA'),
|
||||||
|
nn.Conv3d(32, out_chan, 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.convs(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchGAN42(nn.Module):
|
||||||
|
"""PatchGAN similar to the one in pix2pix
|
||||||
|
"""
|
||||||
|
def __init__(self, in_chan, out_chan=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs = nn.Sequential(
|
||||||
|
nn.Conv3d(in_chan, 64, 4, stride=2),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
|
|
||||||
|
nn.Conv3d(64, 128, 4, stride=2),
|
||||||
|
nn.BatchNorm3d(128),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
|
|
||||||
|
nn.Conv3d(128, 256, 4, stride=2),
|
||||||
|
nn.BatchNorm3d(256),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True),
|
||||||
|
|
||||||
|
nn.Conv3d(256, out_chan, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -58,7 +58,7 @@ def gpu_worker(local_rank, args):
|
|||||||
noise_chan=args.noise_chan,
|
noise_chan=args.noise_chan,
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=rank,
|
rank=args.rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
@ -88,7 +88,7 @@ def gpu_worker(local_rank, args):
|
|||||||
noise_chan=args.noise_chan,
|
noise_chan=args.noise_chan,
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=rank,
|
rank=args.rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
|
Loading…
Reference in New Issue
Block a user