Add optional trailing activation to residual block
This commit is contained in:
parent
a1e5399311
commit
13edf3b96d
1 changed files with 18 additions and 11 deletions
|
@ -75,16 +75,28 @@ class ConvBlock(nn.Module):
|
|||
|
||||
|
||||
class ResBlock(ConvBlock):
|
||||
"""Residual convolution blocks of the form specified by `seq`. Input is added
|
||||
to the residual followed by an optional activation (trailing `'A'` in `seq`).
|
||||
"""Residual convolution blocks of the form specified by `seq`.
|
||||
Input, via a skip connection, is added to the residual followed by an
|
||||
optional activation.
|
||||
|
||||
The skip connection is identity if `out_chan` is omitted, otherwise it uses
|
||||
a size 1 "convolution", i.e. one can trigger the latter by setting
|
||||
`out_chan` even if it equals `in_chan`.
|
||||
|
||||
A trailing `'A'` in seq can either operate before or after the addition,
|
||||
depending on the boolean value of `last_act`.
|
||||
|
||||
See `ConvBlock` for `seq` types.
|
||||
"""
|
||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||
seq='CBACBA'):
|
||||
super().__init__(in_chan, out_chan=out_chan,
|
||||
mid_chan=mid_chan,
|
||||
seq=seq[:-1] if seq[-1] == 'A' else seq)
|
||||
seq='CBACBA', last_act=True):
|
||||
if seq[-1] == 'A' and last_act:
|
||||
seq = seq[:-1]
|
||||
self.act = nn.LeakyReLU()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||
|
||||
if out_chan is None:
|
||||
self.skip = None
|
||||
|
@ -95,11 +107,6 @@ class ResBlock(ConvBlock):
|
|||
raise NotImplementedError('upsample and downsample layers '
|
||||
'not supported yet')
|
||||
|
||||
if seq[-1] == 'A':
|
||||
self.act = nn.LeakyReLU()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
|
||||
|
|
Loading…
Reference in a new issue