Add a simple PatchGAN
This commit is contained in:
parent
c68b9928ee
commit
cdb00ebd8d
18
map2map/models/patchgan.py
Normal file
18
map2map/models/patchgan.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .conv import ConvBlock
|
||||||
|
|
||||||
|
|
||||||
|
class PatchGAN(nn.Module):
|
||||||
|
def __init__(self, in_chan, out_chan=1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.convs = nn.Sequential(
|
||||||
|
ConvBlock(in_chan, 64, seq='CA'),
|
||||||
|
ConvBlock(64, 128, seq='CBA'),
|
||||||
|
ConvBlock(128, 256, seq='CBA'),
|
||||||
|
nn.Conv3d(256, out_chan, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.convs(x)
|
Loading…
Reference in New Issue
Block a user