Add pixel-shifting data augmentation
This commit is contained in:
parent
8a95d69818
commit
28ec245a7a
@ -100,6 +100,10 @@ def add_train_args(parser):
|
|||||||
help='comma-sep. list of glob patterns for validation target data')
|
help='comma-sep. list of glob patterns for validation target data')
|
||||||
parser.add_argument('--augment', action='store_true',
|
parser.add_argument('--augment', action='store_true',
|
||||||
help='enable data augmentation of axis flipping and permutation')
|
help='enable data augmentation of axis flipping and permutation')
|
||||||
|
parser.add_argument('--aug-shift', type=int,
|
||||||
|
help='data augmentation by shifting [0, aug_shift) pixels, '
|
||||||
|
'useful for models that treat neighboring pixels differently, '
|
||||||
|
'e.g. with strided convolutions')
|
||||||
parser.add_argument('--aug-add', type=float,
|
parser.add_argument('--aug-add', type=float,
|
||||||
help='additive data augmentation, (normal) std, '
|
help='additive data augmentation, (normal) std, '
|
||||||
'same factor for all fields')
|
'same factor for all fields')
|
||||||
|
@ -22,7 +22,11 @@ class FieldDataset(Dataset):
|
|||||||
Likewise for `tgt_norms`.
|
Likewise for `tgt_norms`.
|
||||||
|
|
||||||
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||||
In 3D these form the full octahedral symmetry known as the Oh point group.
|
In 3D these form the full octahedral symmetry, the Oh group of order 48.
|
||||||
|
In 2D this is the dihedral group D4 of order 8.
|
||||||
|
1D is not supported, but can be done easily by preprocessing.
|
||||||
|
Fields can be augmented by random shift by a few pixels, useful for models
|
||||||
|
that treat neighboring pixels differently, e.g. with strided convolutions.
|
||||||
Additive and multiplicative augmentation are also possible, but with all fields
|
Additive and multiplicative augmentation are also possible, but with all fields
|
||||||
added or multiplied by the same factor.
|
added or multiplied by the same factor.
|
||||||
|
|
||||||
@ -44,7 +48,7 @@ class FieldDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns,
|
def __init__(self, in_patterns, tgt_patterns,
|
||||||
in_norms=None, tgt_norms=None, callback_at=None,
|
in_norms=None, tgt_norms=None, callback_at=None,
|
||||||
augment=False, aug_add=None, aug_mul=None,
|
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||||
pad=0, scale_factor=1,
|
pad=0, scale_factor=1,
|
||||||
cache=False, cache_maxsize=None, div_data=False,
|
cache=False, cache_maxsize=None, div_data=False,
|
||||||
@ -85,6 +89,7 @@ class FieldDataset(Dataset):
|
|||||||
self.augment = augment
|
self.augment = augment
|
||||||
if self.ndim == 1 and self.augment:
|
if self.ndim == 1 and self.augment:
|
||||||
raise ValueError('cannot augment 1D fields')
|
raise ValueError('cannot augment 1D fields')
|
||||||
|
self.aug_shift = np.broadcast_to(aug_shift, (self.ndim,))
|
||||||
self.aug_add = aug_add
|
self.aug_add = aug_add
|
||||||
self.aug_mul = aug_mul
|
self.aug_mul = aug_mul
|
||||||
|
|
||||||
@ -163,6 +168,10 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
anchor = self.anchors[idx % self.ncrop]
|
anchor = self.anchors[idx % self.ncrop]
|
||||||
|
|
||||||
|
for d, shift in enumerate(self.aug_shift):
|
||||||
|
if shift is not None:
|
||||||
|
anchor[d] += torch.randint(shift, (1,))
|
||||||
|
|
||||||
in_fields = crop(in_fields, anchor, self.crop, self.pad)
|
in_fields = crop(in_fields, anchor, self.crop, self.pad)
|
||||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor,
|
||||||
|
@ -21,6 +21,7 @@ def test(args):
|
|||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
callback_at=args.callback_at,
|
callback_at=args.callback_at,
|
||||||
augment=False,
|
augment=False,
|
||||||
|
aug_shift=None,
|
||||||
aug_add=None,
|
aug_add=None,
|
||||||
aug_mul=None,
|
aug_mul=None,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
|
@ -65,6 +65,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
callback_at=args.callback_at,
|
callback_at=args.callback_at,
|
||||||
augment=args.augment,
|
augment=args.augment,
|
||||||
|
aug_shift=args.aug_shift,
|
||||||
aug_add=args.aug_add,
|
aug_add=args.aug_add,
|
||||||
aug_mul=args.aug_mul,
|
aug_mul=args.aug_mul,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
@ -107,6 +108,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
callback_at=args.callback_at,
|
callback_at=args.callback_at,
|
||||||
augment=False,
|
augment=False,
|
||||||
|
aug_shift=None,
|
||||||
aug_add=None,
|
aug_add=None,
|
||||||
aug_mul=None,
|
aug_mul=None,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
|
Loading…
Reference in New Issue
Block a user