Fix module before __init__ bug
This commit is contained in:
parent
61ca400942
commit
ebd962e333
@ -1,3 +1,4 @@
|
|||||||
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -84,20 +85,31 @@ class ResBlock(ConvBlock):
|
|||||||
`out_chan` even if it equals `in_chan`.
|
`out_chan` even if it equals `in_chan`.
|
||||||
|
|
||||||
A trailing `'A'` in seq can either operate before or after the addition,
|
A trailing `'A'` in seq can either operate before or after the addition,
|
||||||
depending on the boolean value of `last_act`.
|
depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`
|
||||||
|
|
||||||
See `ConvBlock` for `seq` types.
|
See `ConvBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||||
seq='CBACBA', last_act=True):
|
seq='CBACBA', last_act=None):
|
||||||
if seq[-1] == 'A' and last_act:
|
if last_act is None:
|
||||||
|
last_act = seq[-1] == 'A'
|
||||||
|
elif last_act and seq[-1] != 'A':
|
||||||
|
warnings.warn(
|
||||||
|
'Disabling last_act without trailing activation in seq',
|
||||||
|
RuntimeWarning,
|
||||||
|
)
|
||||||
|
last_act = False
|
||||||
|
|
||||||
|
if last_act:
|
||||||
seq = seq[:-1]
|
seq = seq[:-1]
|
||||||
|
|
||||||
|
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||||
|
|
||||||
|
if last_act:
|
||||||
self.act = nn.LeakyReLU()
|
self.act = nn.LeakyReLU()
|
||||||
else:
|
else:
|
||||||
self.act = None
|
self.act = None
|
||||||
|
|
||||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
|
||||||
|
|
||||||
if out_chan is None:
|
if out_chan is None:
|
||||||
self.skip = None
|
self.skip = None
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user