Fix module before __init__ bug
This commit is contained in:
parent
61ca400942
commit
ebd962e333
1 changed files with 17 additions and 5 deletions
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -84,20 +85,31 @@ class ResBlock(ConvBlock):
|
|||
`out_chan` even if it equals `in_chan`.
|
||||
|
||||
A trailing `'A'` in seq can either operate before or after the addition,
|
||||
depending on the boolean value of `last_act`.
|
||||
depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`
|
||||
|
||||
See `ConvBlock` for `seq` types.
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||
seq='CBACBA', last_act=True):
|
||||
if seq[-1] == 'A' and last_act:
|
||||
seq='CBACBA', last_act=None):
|
||||
if last_act is None:
|
||||
last_act = seq[-1] == 'A'
|
||||
elif last_act and seq[-1] != 'A':
|
||||
warnings.warn(
|
||||
'Disabling last_act without trailing activation in seq',
|
||||
RuntimeWarning,
|
||||
)
|
||||
last_act = False
|
||||
|
||||
if last_act:
|
||||
seq = seq[:-1]
|
||||
|
||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||
|
||||
if last_act:
|
||||
self.act = nn.LeakyReLU()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||
|
||||
if out_chan is None:
|
||||
self.skip = None
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue