parent
a61990ee45
commit
01c2e45430
@ -3,10 +3,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .narrow import narrow_like
|
||||
from .swish import Swish
|
||||
|
||||
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
||||
|
||||
|
||||
class ConvStyledBlock(nn.Module):
|
||||
"""Convolution blocks of the form specified by `seq`.
|
||||
|
||||
@ -58,7 +57,7 @@ class ConvStyledBlock(nn.Module):
|
||||
elif l == 'A':
|
||||
return LeakyReLUStyled()
|
||||
else:
|
||||
raise NotImplementedError('layer type {} not supported'.format(l))
|
||||
raise ValueError('layer type {} not supported'.format(l))
|
||||
|
||||
def _setup_conv(self):
|
||||
self.idx_conv += 1
|
||||
@ -94,7 +93,7 @@ class ResStyledBlock(ConvStyledBlock):
|
||||
See `ConvStyledBlock` for `seq` types.
|
||||
"""
|
||||
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||
seq='CBACBA', last_act=None):
|
||||
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||
if last_act is None:
|
||||
last_act = seq[-1] == 'A'
|
||||
elif last_act and seq[-1] != 'A':
|
||||
@ -107,7 +106,8 @@ class ResStyledBlock(ConvStyledBlock):
|
||||
if last_act:
|
||||
seq = seq[:-1]
|
||||
|
||||
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||
|
||||
if last_act:
|
||||
self.act = LeakyReLUStyled()
|
||||
|
@ -27,7 +27,10 @@ class StyledVNet(nn.Module):
|
||||
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
|
||||
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
|
||||
|
||||
self.bypass = in_chan == out_chan
|
||||
if bypass is None:
|
||||
self.bypass = in_chan == out_chan
|
||||
else:
|
||||
self.bypass = bypass
|
||||
|
||||
def forward(self, x, s):
|
||||
if self.bypass:
|
||||
|
Loading…
Reference in New Issue
Block a user