Add optional trailing activation to residual block

This commit is contained in:
Yin Li 2020-08-09 12:18:17 -07:00
parent a1e5399311
commit 13edf3b96d

View File

@ -75,16 +75,28 @@ class ConvBlock(nn.Module):
class ResBlock(ConvBlock): class ResBlock(ConvBlock):
"""Residual convolution blocks of the form specified by `seq`. Input is added """Residual convolution blocks of the form specified by `seq`.
to the residual followed by an optional activation (trailing `'A'` in `seq`). Input, via a skip connection, is added to the residual followed by an
optional activation.
The skip connection is identity if `out_chan` is omitted, otherwise it uses
a size 1 "convolution", i.e. one can trigger the latter by setting
`out_chan` even if it equals `in_chan`.
A trailing `'A'` in seq can either operate before or after the addition,
depending on the boolean value of `last_act`.
See `ConvBlock` for `seq` types. See `ConvBlock` for `seq` types.
""" """
def __init__(self, in_chan, out_chan=None, mid_chan=None, def __init__(self, in_chan, out_chan=None, mid_chan=None,
seq='CBACBA'): seq='CBACBA', last_act=True):
super().__init__(in_chan, out_chan=out_chan, if seq[-1] == 'A' and last_act:
mid_chan=mid_chan, seq = seq[:-1]
seq=seq[:-1] if seq[-1] == 'A' else seq) self.act = nn.LeakyReLU()
else:
self.act = None
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
if out_chan is None: if out_chan is None:
self.skip = None self.skip = None
@ -95,11 +107,6 @@ class ResBlock(ConvBlock):
raise NotImplementedError('upsample and downsample layers ' raise NotImplementedError('upsample and downsample layers '
'not supported yet') 'not supported yet')
if seq[-1] == 'A':
self.act = nn.LeakyReLU()
else:
self.act = None
def forward(self, x): def forward(self, x):
y = x y = x