parent
a61990ee45
commit
01c2e45430
@ -3,10 +3,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .narrow import narrow_like
|
from .narrow import narrow_like
|
||||||
from .swish import Swish
|
|
||||||
|
|
||||||
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
||||||
|
|
||||||
|
|
||||||
class ConvStyledBlock(nn.Module):
|
class ConvStyledBlock(nn.Module):
|
||||||
"""Convolution blocks of the form specified by `seq`.
|
"""Convolution blocks of the form specified by `seq`.
|
||||||
|
|
||||||
@ -58,7 +57,7 @@ class ConvStyledBlock(nn.Module):
|
|||||||
elif l == 'A':
|
elif l == 'A':
|
||||||
return LeakyReLUStyled()
|
return LeakyReLUStyled()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('layer type {} not supported'.format(l))
|
raise ValueError('layer type {} not supported'.format(l))
|
||||||
|
|
||||||
def _setup_conv(self):
|
def _setup_conv(self):
|
||||||
self.idx_conv += 1
|
self.idx_conv += 1
|
||||||
@ -94,7 +93,7 @@ class ResStyledBlock(ConvStyledBlock):
|
|||||||
See `ConvStyledBlock` for `seq` types.
|
See `ConvStyledBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||||
seq='CBACBA', last_act=None):
|
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||||
if last_act is None:
|
if last_act is None:
|
||||||
last_act = seq[-1] == 'A'
|
last_act = seq[-1] == 'A'
|
||||||
elif last_act and seq[-1] != 'A':
|
elif last_act and seq[-1] != 'A':
|
||||||
@ -107,7 +106,8 @@ class ResStyledBlock(ConvStyledBlock):
|
|||||||
if last_act:
|
if last_act:
|
||||||
seq = seq[:-1]
|
seq = seq[:-1]
|
||||||
|
|
||||||
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||||
|
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||||
|
|
||||||
if last_act:
|
if last_act:
|
||||||
self.act = LeakyReLUStyled()
|
self.act = LeakyReLUStyled()
|
||||||
|
@ -27,7 +27,10 @@ class StyledVNet(nn.Module):
|
|||||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
||||||
|
|
||||||
self.bypass = in_chan == out_chan
|
if bypass is None:
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
else:
|
||||||
|
self.bypass = bypass
|
||||||
|
|
||||||
def forward(self, x, s):
|
def forward(self, x, s):
|
||||||
if self.bypass:
|
if self.bypass:
|
||||||
|
Loading…
Reference in New Issue
Block a user