Add additive and multiplicative augmentation

This commit is contained in:
Yin Li 2020-05-07 15:35:49 -04:00
parent 67e5ed9eb6
commit 897c3563db
4 changed files with 72 additions and 19 deletions

View File

@ -73,7 +73,13 @@ def add_train_args(parser):
parser.add_argument('--val-tgt-patterns', type=str_list, parser.add_argument('--val-tgt-patterns', type=str_list,
help='comma-sep. list of glob patterns for validation target data') help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true', parser.add_argument('--augment', action='store_true',
help='enable training data augmentation') help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, '
'same factor for all fields')
parser.add_argument('--aug-mul', type=float,
help='multiplicative data augmentation, (log-normal) std, '
'same factor for all fields')
parser.add_argument('--adv-model', type=str, parser.add_argument('--adv-model', type=str,
help='enable adversary model from .models') help='enable adversary model from .models')

View File

@ -18,7 +18,10 @@ class FieldDataset(Dataset):
`in_norms` is a list of of functions to normalize the input fields. `in_norms` is a list of of functions to normalize the input fields.
Likewise for `tgt_norms`. Likewise for `tgt_norms`.
Data augmentations are supported for scalar and vector fields. Scalar and vector fields can be augmented by flipping and permutating the axes.
In 3D these form the full octahedral symmetry known as the Oh point group.
Additive and multiplicative augmentation are also possible, but with all fields
added or multiplied by the same factor.
Input and target fields can be cropped. Input and target fields can be cropped.
Input fields can be padded assuming periodic boundary condition. Input fields can be padded assuming periodic boundary condition.
@ -32,8 +35,10 @@ class FieldDataset(Dataset):
""" """
def __init__(self, in_patterns, tgt_patterns, def __init__(self, in_patterns, tgt_patterns,
in_norms=None, tgt_norms=None, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1, augment=False, aug_add=None, aug_mul=None,
cache=False, div_data=False, rank=None, world_size=None): crop=None, pad=0, scale_factor=1,
cache=False, div_data=False,
rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists)) self.in_files = list(zip(* in_file_lists))
@ -72,6 +77,8 @@ class FieldDataset(Dataset):
self.augment = augment self.augment = augment
if self.ndim == 1 and self.augment: if self.ndim == 1 and self.augment:
raise ValueError('cannot augment 1D fields') raise ValueError('cannot augment 1D fields')
self.aug_add = aug_add
self.aug_mul = aug_mul
if crop is None: if crop is None:
self.crop = self.size self.crop = self.size
@ -121,18 +128,6 @@ class FieldDataset(Dataset):
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
if self.augment:
flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
flip_axes = torch.arange(self.ndim)[flip_axes]
in_fields = flip(in_fields, flip_axes, self.ndim)
tgt_fields = flip(tgt_fields, flip_axes, self.ndim)
perm_axes = torch.randperm(self.ndim)
in_fields = perm(in_fields, perm_axes, self.ndim)
tgt_fields = perm(tgt_fields, perm_axes, self.ndim)
if self.in_norms is not None: if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields): for norm, x in zip(self.in_norms, in_fields):
norm(x) norm(x)
@ -140,6 +135,21 @@ class FieldDataset(Dataset):
for norm, x in zip(self.tgt_norms, tgt_fields): for norm, x in zip(self.tgt_norms, tgt_fields):
norm(x) norm(x)
if self.augment:
in_fields, flip_axes = flip(in_fields, None, self.ndim)
tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim)
in_fields, perm_axes = perm(in_fields, None, self.ndim)
tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim)
if self.aug_add is not None:
add_fac = add(in_fields, None, self.aug_add)
add_fac = add(tgt_fields, add_fac, self.aug_add)
if self.aug_mul is not None:
mul_fac = mul(in_fields, None, self.aug_mul)
mul_fac = mul(tgt_fields, mul_fac, self.aug_mul)
in_fields = torch.cat(in_fields, dim=0) in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0)
@ -159,7 +169,11 @@ def crop(fields, start, crop, pad):
def flip(fields, axes, ndim): def flip(fields, axes, ndim):
assert ndim > 1, 'flipping is ambiguous for 1D vectors' assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
if axes is None:
axes = torch.randint(2, (ndim,), dtype=torch.bool)
axes = torch.arange(ndim)[axes]
new_fields = [] new_fields = []
for x in fields: for x in fields:
@ -171,12 +185,15 @@ def flip(fields, axes, ndim):
new_fields.append(x) new_fields.append(x)
return new_fields return new_fields, axes
def perm(fields, axes, ndim): def perm(fields, axes, ndim):
assert ndim > 1, 'permutation is not necessary for 1D fields' assert ndim > 1, 'permutation is not necessary for 1D fields'
if axes is None:
axes = torch.randperm(ndim)
new_fields = [] new_fields = []
for x in fields: for x in fields:
if x.shape[0] == ndim: # permutate vector components if x.shape[0] == ndim: # permutate vector components
@ -187,4 +204,28 @@ def perm(fields, axes, ndim):
new_fields.append(x) new_fields.append(x)
return new_fields return new_fields, axes
def add(fields, fac, std):
if fac is None:
x = fields[0]
fac = torch.zeros((x.shape[0],) + (1,) * (x.dim() - 1))
fac.normal_(mean=0, std=std)
for x in fields:
x += fac
return fac
def mul(fields, fac, std):
if fac is None:
x = fields[0]
fac = torch.ones((x.shape[0],) + (1,) * (x.dim() - 1))
fac.log_normal_(mean=0, std=std)
for x in fields:
x *= fac
return fac

View File

@ -19,6 +19,8 @@ def test(args):
in_norms=args.in_norms, in_norms=args.in_norms,
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
augment=False, augment=False,
aug_add=None,
aug_mul=None,
crop=args.crop, crop=args.crop,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,

View File

@ -63,6 +63,8 @@ def gpu_worker(local_rank, node, args):
in_norms=args.in_norms, in_norms=args.in_norms,
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
augment=args.augment, augment=args.augment,
aug_add=args.aug_add,
aug_mul=args.aug_mul,
crop=args.crop, crop=args.crop,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,
@ -92,6 +94,8 @@ def gpu_worker(local_rank, node, args):
in_norms=args.in_norms, in_norms=args.in_norms,
tgt_norms=args.tgt_norms, tgt_norms=args.tgt_norms,
augment=False, augment=False,
aug_add=None,
aug_mul=None,
crop=args.crop, crop=args.crop,
pad=args.pad, pad=args.pad,
scale_factor=args.scale_factor, scale_factor=args.scale_factor,