Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul

This commit is contained in:
Yin Li 2020-07-15 19:04:34 -07:00
commit 5bc099251f
2 changed files with 3 additions and 4 deletions

View File

@ -70,7 +70,7 @@ def add_common_args(parser):
parser.add_argument('--batches', type=int, required=True, parser.add_argument('--batches', type=int, required=True,
help='mini-batch size, per GPU in training or in total in testing') help='mini-batch size, per GPU in training or in total in testing')
parser.add_argument('--loader-workers', default=-8, type=int, parser.add_argument('--loader-workers', default=-2, type=int,
help='number of subprocesses per data loader. ' help='number of subprocesses per data loader. '
'0 to disable multiprocessing; ' '0 to disable multiprocessing; '
'negative number to multiply by the batch size') 'negative number to multiply by the batch size')

View File

@ -5,9 +5,8 @@ import torch.nn as nn
def narrow_by(a, c): def narrow_by(a, c):
"""Narrow a by size c symmetrically on all edges. """Narrow a by size c symmetrically on all edges.
""" """
for d in range(2, a.dim()): ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
a = a.narrow(d, c, a.shape[d] - 2 * c) return a[tuple(ind)]
return a
def narrow_cast(*tensors): def narrow_cast(*tensors):