Add a simple PatchGAN

This commit is contained in:
Yin Li 2020-01-22 14:01:18 -05:00
parent c68b9928ee
commit cdb00ebd8d

View File

@ -0,0 +1,18 @@
import torch.nn as nn
from .conv import ConvBlock
class PatchGAN(nn.Module):
def __init__(self, in_chan, out_chan=1):
super().__init__()
self.convs = nn.Sequential(
ConvBlock(in_chan, 64, seq='CA'),
ConvBlock(64, 128, seq='CBA'),
ConvBlock(128, 256, seq='CBA'),
nn.Conv3d(256, out_chan, 1)
)
def forward(self, x):
return self.convs(x)