Fix augmentation bug
This commit is contained in:
parent
1f89e894cc
commit
46a3a3a97d
@ -156,8 +156,8 @@ def flip(fields, axes, ndim):
|
||||
if x.shape[0] == ndim: # flip vector components
|
||||
x[axes] = - x[axes]
|
||||
|
||||
axes = (1 + axes).tolist()
|
||||
x = torch.flip(x, axes)
|
||||
shifted_axes = (1 + axes).tolist()
|
||||
x = torch.flip(x, shifted_axes)
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
@ -172,8 +172,8 @@ def perm(fields, axes, ndim):
|
||||
if x.shape[0] == ndim: # permutate vector components
|
||||
x = x[axes]
|
||||
|
||||
axes = [0] + (1 + axes).tolist()
|
||||
x = x.permute(axes)
|
||||
shifted_axes = [0] + (1 + axes).tolist()
|
||||
x = x.permute(shifted_axes)
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user