Add pixel-shifting data augmentation

This commit is contained in:
Yin Li 2020-07-11 00:46:13 -04:00
parent 8a95d69818
commit 28ec245a7a
4 changed files with 18 additions and 2 deletions

View File

@ -100,6 +100,10 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation target data') help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true', parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation') help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int,
help='data augmentation by shifting [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions')
parser.add_argument('--aug-add', type=float, parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, ' help='additive data augmentation, (normal) std, '
'same factor for all fields') 'same factor for all fields')

View File

@ -22,7 +22,11 @@ class FieldDataset(Dataset):
Likewise for `tgt_norms`. Likewise for `tgt_norms`.
Scalar and vector fields can be augmented by flipping and permutating the axes. Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry known as the Oh point group. In 3D these form the full octahedral symmetry, the Oh group of order 48.
In 2D this is the dihedral group D4 of order 8.
1D is not supported, but can be done easily by preprocessing.
Fields can be augmented by random shift by a few pixels, useful for models
that treat neighboring pixels differently, e.g. with strided convolutions.
Additive and multiplicative augmentation are also possible, but with all fields Additive and multiplicative augmentation are also possible, but with all fields
added or multiplied by the same factor. added or multiplied by the same factor.
@ -44,7 +48,7 @@ class FieldDataset(Dataset):
""" """
def __init__(self, in_patterns, tgt_patterns, def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, callback_at=None, in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_add=None, aug_mul=None, augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None, crop=None, crop_start=None, crop_stop=None, crop_step=None,
pad=0, scale_factor=1, pad=0, scale_factor=1,
cache=False, cache_maxsize=None, div_data=False, cache=False, cache_maxsize=None, div_data=False,
@ -85,6 +89,7 @@ class FieldDataset(Dataset):
self.augment = augment self.augment = augment
if self.ndim == 1 and self.augment: if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields') raise ValueError('cannot augment 1D fields')
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
self.aug_add = aug_add self.aug_add = aug_add
self.aug_mul = aug_mul self.aug_mul = aug_mul
@ -163,6 +168,10 @@ class FieldDataset(Dataset):
anchor = self.anchors[idx % self.ncrop] anchor = self.anchors[idx % self.ncrop]
for d, shift in enumerate(self.aug_shift):
if shift is not None:
anchor[d] += torch.randint(shift, (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad) in_fields = crop(in_fields, anchor, self.crop, self.pad)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor, tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor,

View File

@ -21,6 +21,7 @@ def test(args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=False, augment=False,
aug_shift=None,
aug_add=None, aug_add=None,
aug_mul=None, aug_mul=None,
crop=args.crop, crop=args.crop,

View File

@ -65,6 +65,7 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=args.augment, augment=args.augment,
aug_shift=args.aug_shift,
aug_add=args.aug_add, aug_add=args.aug_add,
aug_mul=args.aug_mul, aug_mul=args.aug_mul,
crop=args.crop, crop=args.crop,
@ -107,6 +108,7 @@ def gpu_worker(local_rank, node, args):
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
callback_at=args.callback_at, callback_at=args.callback_at,
augment=False, augment=False,
aug_shift=None,
aug_add=None, aug_add=None,
aug_mul=None, aug_mul=None,
crop=args.crop, crop=args.crop,