Merge remote-tracking branch 'refs/remotes/origin/lag2eul' into lag2eul
This commit is contained in:
commit
5bc099251f
@ -70,7 +70,7 @@ def add_common_args(parser):
|
|||||||
|
|
||||||
parser.add_argument('--batches', type=int, required=True,
|
parser.add_argument('--batches', type=int, required=True,
|
||||||
help='mini-batch size, per GPU in training or in total in testing')
|
help='mini-batch size, per GPU in training or in total in testing')
|
||||||
parser.add_argument('--loader-workers', default=-8, type=int,
|
parser.add_argument('--loader-workers', default=-2, type=int,
|
||||||
help='number of subprocesses per data loader. '
|
help='number of subprocesses per data loader. '
|
||||||
'0 to disable multiprocessing; '
|
'0 to disable multiprocessing; '
|
||||||
'negative number to multiply by the batch size')
|
'negative number to multiply by the batch size')
|
||||||
|
@ -5,9 +5,8 @@ import torch.nn as nn
|
|||||||
def narrow_by(a, c):
|
def narrow_by(a, c):
|
||||||
"""Narrow a by size c symmetrically on all edges.
|
"""Narrow a by size c symmetrically on all edges.
|
||||||
"""
|
"""
|
||||||
for d in range(2, a.dim()):
|
ind = [slice(None)] * 2 + [slice(c, -c)] * (a.dim() - 2)
|
||||||
a = a.narrow(d, c, a.shape[d] - 2 * c)
|
return a[tuple(ind)]
|
||||||
return a
|
|
||||||
|
|
||||||
|
|
||||||
def narrow_cast(*tensors):
|
def narrow_cast(*tensors):
|
||||||
|
Loading…
Reference in New Issue
Block a user