Fix augmentation bug
This commit is contained in:
parent
1f89e894cc
commit
46a3a3a97d
@ -156,8 +156,8 @@ def flip(fields, axes, ndim):
|
|||||||
if x.shape[0] == ndim: # flip vector components
|
if x.shape[0] == ndim: # flip vector components
|
||||||
x[axes] = - x[axes]
|
x[axes] = - x[axes]
|
||||||
|
|
||||||
axes = (1 + axes).tolist()
|
shifted_axes = (1 + axes).tolist()
|
||||||
x = torch.flip(x, axes)
|
x = torch.flip(x, shifted_axes)
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
@ -172,8 +172,8 @@ def perm(fields, axes, ndim):
|
|||||||
if x.shape[0] == ndim: # permutate vector components
|
if x.shape[0] == ndim: # permutate vector components
|
||||||
x = x[axes]
|
x = x[axes]
|
||||||
|
|
||||||
axes = [0] + (1 + axes).tolist()
|
shifted_axes = [0] + (1 + axes).tolist()
|
||||||
x = x.permute(axes)
|
x = x.permute(shifted_axes)
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user