Remove shape dependence from narrow_by
Good for torchscript
This commit is contained in:
parent
bbf77c9f91
commit
7d3a598080
@ -5,9 +5,8 @@ import torch.nn as nn
|
||||
def narrow_by(a, c):
|
||||
"""Narrow a by size c symmetrically on all edges.
|
||||
"""
|
||||
for d in range(2, a.dim()):
|
||||
a = a.narrow(d, c, a.shape[d] - 2 * c)
|
||||
return a
|
||||
ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
|
||||
return a[tuple(ind)]
|
||||
|
||||
|
||||
def narrow_cast(*tensors):
|
||||
|
Loading…
Reference in New Issue
Block a user