Remove shape dependence from narrow_by

Good for torchscript
This commit is contained in:
Yin Li 2020-07-15 21:36:37 -04:00
parent bbf77c9f91
commit 7d3a598080

View File

@ -5,9 +5,8 @@ import torch.nn as nn
def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges.
"""
for d in range(2, a.dim()):
a = a.narrow(d, c, a.shape[d] - 2 * c)
return a
ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
return a[tuple(ind)]
def narrow_cast(*tensors):