Remove shape dependence from narrow_by

Good for torchscript
This commit is contained in:
Yin Li 2020-07-15 21:36:37 -04:00
parent bbf77c9f91
commit 7d3a598080

View File

@ -5,9 +5,8 @@ import torch.nn as nn
def narrow_by(a, c): def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges. """Narrow a by size c symmetrically on all edges.
""" """
for d in range(2, a.dim()): ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
a = a.narrow(d, c, a.shape[d] - 2 * c) return a[tuple(ind)]
return a
def narrow_cast(*tensors): def narrow_cast(*tensors):