Add optional trailing activation to residual block
This commit is contained in:
parent
a1e5399311
commit
13edf3b96d
@ -75,16 +75,28 @@ class ConvBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class ResBlock(ConvBlock):
|
class ResBlock(ConvBlock):
|
||||||
"""Residual convolution blocks of the form specified by `seq`. Input is added
|
"""Residual convolution blocks of the form specified by `seq`.
|
||||||
to the residual followed by an optional activation (trailing `'A'` in `seq`).
|
Input, via a skip connection, is added to the residual followed by an
|
||||||
|
optional activation.
|
||||||
|
|
||||||
|
The skip connection is identity if `out_chan` is omitted, otherwise it uses
|
||||||
|
a size 1 "convolution", i.e. one can trigger the latter by setting
|
||||||
|
`out_chan` even if it equals `in_chan`.
|
||||||
|
|
||||||
|
A trailing `'A'` in seq can either operate before or after the addition,
|
||||||
|
depending on the boolean value of `last_act`.
|
||||||
|
|
||||||
See `ConvBlock` for `seq` types.
|
See `ConvBlock` for `seq` types.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
def __init__(self, in_chan, out_chan=None, mid_chan=None,
|
||||||
seq='CBACBA'):
|
seq='CBACBA', last_act=True):
|
||||||
super().__init__(in_chan, out_chan=out_chan,
|
if seq[-1] == 'A' and last_act:
|
||||||
mid_chan=mid_chan,
|
seq = seq[:-1]
|
||||||
seq=seq[:-1] if seq[-1] == 'A' else seq)
|
self.act = nn.LeakyReLU()
|
||||||
|
else:
|
||||||
|
self.act = None
|
||||||
|
|
||||||
|
super().__init__(in_chan, out_chan=out_chan, mid_chan=mid_chan, seq=seq)
|
||||||
|
|
||||||
if out_chan is None:
|
if out_chan is None:
|
||||||
self.skip = None
|
self.skip = None
|
||||||
@ -95,11 +107,6 @@ class ResBlock(ConvBlock):
|
|||||||
raise NotImplementedError('upsample and downsample layers '
|
raise NotImplementedError('upsample and downsample layers '
|
||||||
'not supported yet')
|
'not supported yet')
|
||||||
|
|
||||||
if seq[-1] == 'A':
|
|
||||||
self.act = nn.LeakyReLU()
|
|
||||||
else:
|
|
||||||
self.act = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x
|
y = x
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user