Add pixel-shifting data augmentation

This commit is contained in:
Yin Li 2020-07-11 00:46:13 -04:00
parent 8a95d69818
commit 28ec245a7a
4 changed files with 18 additions and 2 deletions

View file

@ -100,6 +100,10 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int,
help='data augmentation by shifting [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions')
parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, '
'same factor for all fields')

View file

@ -22,7 +22,11 @@ class FieldDataset(Dataset):
Likewise for `tgt_norms`.
Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry known as the Oh point group.
In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8.
1D is not supported, but can be done easily by preprocessing.
Fields can be augmented by random shift by a few pixels, useful for models
that treat neighboring pixels differently, e.g. with strided convolutions.
Additive and multiplicative augmentation are also possible, but with all fields
added or multiplied by the same factor.
@ -44,7 +48,7 @@ class FieldDataset(Dataset):
"""
def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
pad=0, scale_factor=1,
cache=False, cache_maxsize=None, div_data=False,
@ -85,6 +89,7 @@ class FieldDataset(Dataset):
self.augment = augment
if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields')
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
self.aug_add = aug_add
self.aug_mul = aug_mul
@ -163,6 +168,10 @@ class FieldDataset(Dataset):
anchor = self.anchors[idx % self.ncrop]
for d, shift in enumerate(self.aug_shift):
if shift is not None:
anchor[d] += torch.randint(shift, (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor,

View file

@ -21,6 +21,7 @@ def test(args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=args.crop,

View file

@ -65,6 +65,7 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=args.augment,
aug_shift=args.aug_shift,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
crop=args.crop,
@ -107,6 +108,7 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms,
callback_at=args.callback_at,
augment=False,
aug_shift=None,
aug_add=None,
aug_mul=None,
crop=args.crop,