Fix module before __init__ bug

This commit is contained in:
Yin Li 2020-08-12 09:10:58 -07:00
parent 61ca400942
commit ebd962e333

View File

@ -1,3 +1,4 @@
import warnings
import torch
import torch.nn as nn
@ -84,20 +85,31 @@ class ResBlock(ConvBlock):
`out_chan` even if it equals `in_chan`.
A trailing `'A'` in seq can either operate before or after the addition,
depending on the boolean value of `last_act`.
depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`
See `ConvBlock` for `seq` types.
"""
def __init__(self, in_chan, out_chan=None, mid_chan=None,
seq='CBACBA', last_act=True):
if seq[-1] == 'A' and last_act:
seq='CBACBA', last_act=None):
if last_act is None:
last_act = seq[-1] == 'A'
elif last_act and seq[-1] != 'A':
warnings.warn(
'Disabling last_act without trailing activation in seq',
RuntimeWarning,
)
last_act = False
if last_act:
seq = seq[:-1]
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
if last_act:
self.act = nn.LeakyReLU()
else:
self.act = None
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
if out_chan is None:
self.skip = None
else: