Fixes and cleaning following 183a223 55b1a72 8544ff0

This commit is contained in:
Yin Li 2021-03-18 14:37:31 -04:00
parent a61990ee45
commit 01c2e45430
2 changed files with 9 additions and 6 deletions

View File

@ -3,10 +3,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from .narrow import narrow_like from .narrow import narrow_like
from .swish import Swish
from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled from .style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
class ConvStyledBlock(nn.Module): class ConvStyledBlock(nn.Module):
"""Convolution blocks of the form specified by `seq`. """Convolution blocks of the form specified by `seq`.
@ -58,7 +57,7 @@ class ConvStyledBlock(nn.Module):
elif l == 'A': elif l == 'A':
return LeakyReLUStyled() return LeakyReLUStyled()
else: else:
raise NotImplementedError('layer type {} not supported'.format(l)) raise ValueError('layer type {} not supported'.format(l))
def _setup_conv(self): def _setup_conv(self):
self.idx_conv += 1 self.idx_conv += 1
@ -94,7 +93,7 @@ class ResStyledBlock(ConvStyledBlock):
See `ConvStyledBlock` for `seq` types. See `ConvStyledBlock` for `seq` types.
""" """
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None, def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
seq='CBACBA', last_act=None): kernel_size=3, stride=1, seq='CBACBA', last_act=None):
if last_act is None: if last_act is None:
last_act = seq[-1] == 'A' last_act = seq[-1] == 'A'
elif last_act and seq[-1] != 'A': elif last_act and seq[-1] != 'A':
@ -107,7 +106,8 @@ class ResStyledBlock(ConvStyledBlock):
if last_act: if last_act:
seq = seq[:-1] seq = seq[:-1]
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq) super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
kernel_size=kernel_size, stride=stride, seq=seq)
if last_act: if last_act:
self.act = LeakyReLUStyled() self.act = LeakyReLUStyled()

View File

@ -27,7 +27,10 @@ class StyledVNet(nn.Module):
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA') self.up_r0 = ConvStyledBlock(style_size, 64, seq='UBA')
self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC') self.conv_r0 = ResStyledBlock(style_size, 128, out_chan, seq='CAC')
if bypass is None:
self.bypass = in_chan == out_chan self.bypass = in_chan == out_chan
else:
self.bypass = bypass
def forward(self, x, s): def forward(self, x, s):
if self.bypass: if self.bypass: