Add pixel-shifting data augmentation
This commit is contained in:
parent
8a95d69818
commit
28ec245a7a
4 changed files with 18 additions and 2 deletions
|
@ -100,6 +100,10 @@ def add_train_args(parser):
|
|||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable data augmentation of axis flipping and permutation')
|
||||
parser.add_argument('--aug-shift', type=int,
|
||||
help='data augmentation by shifting [0, aug_shift) pixels, '
|
||||
'useful for models that treat neighboring pixels differently, '
|
||||
'e.g. with strided convolutions')
|
||||
parser.add_argument('--aug-add', type=float,
|
||||
help='additive data augmentation, (normal) std, '
|
||||
'same factor for all fields')
|
||||
|
|
|
@ -22,7 +22,11 @@ class FieldDataset(Dataset):
|
|||
Likewise for `tgt_norms`.
|
||||
|
||||
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||
In 3D these form the full octahedral symmetry known as the Oh point group.
|
||||
In 3D these form the full octahedral symmetry, the Oh group of order 48.
|
||||
In 2D this is the dihedral group D4 of order 8.
|
||||
1D is not supported, but can be done easily by preprocessing.
|
||||
Fields can be augmented by random shift by a few pixels, useful for models
|
||||
that treat neighboring pixels differently, e.g. with strided convolutions.
|
||||
Additive and multiplicative augmentation are also possible, but with all fields
|
||||
added or multiplied by the same factor.
|
||||
|
||||
|
@ -44,7 +48,7 @@ class FieldDataset(Dataset):
|
|||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_add=None, aug_mul=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
pad=0, scale_factor=1,
|
||||
cache=False, cache_maxsize=None, div_data=False,
|
||||
|
@ -85,6 +89,7 @@ class FieldDataset(Dataset):
|
|||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
|
||||
self.aug_add = aug_add
|
||||
self.aug_mul = aug_mul
|
||||
|
||||
|
@ -163,6 +168,10 @@ class FieldDataset(Dataset):
|
|||
|
||||
anchor = self.anchors[idx % self.ncrop]
|
||||
|
||||
for d, shift in enumerate(self.aug_shift):
|
||||
if shift is not None:
|
||||
anchor[d] += torch.randint(shift, (1,))
|
||||
|
||||
in_fields = crop(in_fields, anchor, self.crop, self.pad)
|
||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop * self.scale_factor,
|
||||
|
|
|
@ -21,6 +21,7 @@ def test(args):
|
|||
tgt_norms=args.tgt_norms,
|
||||
callback_at=args.callback_at,
|
||||
augment=False,
|
||||
aug_shift=None,
|
||||
aug_add=None,
|
||||
aug_mul=None,
|
||||
crop=args.crop,
|
||||
|
|
|
@ -65,6 +65,7 @@ def gpu_worker(local_rank, node, args):
|
|||
tgt_norms=args.tgt_norms,
|
||||
callback_at=args.callback_at,
|
||||
augment=args.augment,
|
||||
aug_shift=args.aug_shift,
|
||||
aug_add=args.aug_add,
|
||||
aug_mul=args.aug_mul,
|
||||
crop=args.crop,
|
||||
|
@ -107,6 +108,7 @@ def gpu_worker(local_rank, node, args):
|
|||
tgt_norms=args.tgt_norms,
|
||||
callback_at=args.callback_at,
|
||||
augment=False,
|
||||
aug_shift=None,
|
||||
aug_add=None,
|
||||
aug_mul=None,
|
||||
crop=args.crop,
|
||||
|
|
Loading…
Reference in a new issue