Add additive and multiplicative augmentation
This commit is contained in:
parent
67e5ed9eb6
commit
897c3563db
@ -73,7 +73,13 @@ def add_train_args(parser):
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable training data augmentation')
|
||||
help='enable data augmentation of axis flipping and permutation')
|
||||
parser.add_argument('--aug-add', type=float,
|
||||
help='additive data augmentation, (normal) std, '
|
||||
'same factor for all fields')
|
||||
parser.add_argument('--aug-mul', type=float,
|
||||
help='multiplicative data augmentation, (log-normal) std, '
|
||||
'same factor for all fields')
|
||||
|
||||
parser.add_argument('--adv-model', type=str,
|
||||
help='enable adversary model from .models')
|
||||
|
@ -18,7 +18,10 @@ class FieldDataset(Dataset):
|
||||
`in_norms` is a list of of functions to normalize the input fields.
|
||||
Likewise for `tgt_norms`.
|
||||
|
||||
Data augmentations are supported for scalar and vector fields.
|
||||
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||
In 3D these form the full octahedral symmetry known as the Oh point group.
|
||||
Additive and multiplicative augmentation are also possible, but with all fields
|
||||
added or multiplied by the same factor.
|
||||
|
||||
Input and target fields can be cropped.
|
||||
Input fields can be padded assuming periodic boundary condition.
|
||||
@ -32,8 +35,10 @@ class FieldDataset(Dataset):
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns,
|
||||
in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False, rank=None, world_size=None):
|
||||
augment=False, aug_add=None, aug_mul=None,
|
||||
crop=None, pad=0, scale_factor=1,
|
||||
cache=False, div_data=False,
|
||||
rank=None, world_size=None):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
@ -72,6 +77,8 @@ class FieldDataset(Dataset):
|
||||
self.augment = augment
|
||||
if self.ndim == 1 and self.augment:
|
||||
raise ValueError('cannot augment 1D fields')
|
||||
self.aug_add = aug_add
|
||||
self.aug_mul = aug_mul
|
||||
|
||||
if crop is None:
|
||||
self.crop = self.size
|
||||
@ -121,18 +128,6 @@ class FieldDataset(Dataset):
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
|
||||
if self.augment:
|
||||
flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
|
||||
flip_axes = torch.arange(self.ndim)[flip_axes]
|
||||
|
||||
in_fields = flip(in_fields, flip_axes, self.ndim)
|
||||
tgt_fields = flip(tgt_fields, flip_axes, self.ndim)
|
||||
|
||||
perm_axes = torch.randperm(self.ndim)
|
||||
|
||||
in_fields = perm(in_fields, perm_axes, self.ndim)
|
||||
tgt_fields = perm(tgt_fields, perm_axes, self.ndim)
|
||||
|
||||
if self.in_norms is not None:
|
||||
for norm, x in zip(self.in_norms, in_fields):
|
||||
norm(x)
|
||||
@ -140,6 +135,21 @@ class FieldDataset(Dataset):
|
||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm(x)
|
||||
|
||||
if self.augment:
|
||||
in_fields, flip_axes = flip(in_fields, None, self.ndim)
|
||||
tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
||||
|
||||
in_fields, perm_axes = perm(in_fields, None, self.ndim)
|
||||
tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
||||
|
||||
if self.aug_add is not None:
|
||||
add_fac = add(in_fields, None, self.aug_add)
|
||||
add_fac = add(tgt_fields, add_fac, self.aug_add)
|
||||
|
||||
if self.aug_mul is not None:
|
||||
mul_fac = mul(in_fields, None, self.aug_mul)
|
||||
mul_fac = mul(tgt_fields, mul_fac, self.aug_mul)
|
||||
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
@ -159,7 +169,11 @@ def crop(fields, start, crop, pad):
|
||||
|
||||
|
||||
def flip(fields, axes, ndim):
|
||||
assert ndim > 1, 'flipping is ambiguous for 1D vectors'
|
||||
assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
|
||||
|
||||
if axes is None:
|
||||
axes = torch.randint(2, (ndim,), dtype=torch.bool)
|
||||
axes = torch.arange(ndim)[axes]
|
||||
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
@ -171,12 +185,15 @@ def flip(fields, axes, ndim):
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
return new_fields
|
||||
return new_fields, axes
|
||||
|
||||
|
||||
def perm(fields, axes, ndim):
|
||||
assert ndim > 1, 'permutation is not necessary for 1D fields'
|
||||
|
||||
if axes is None:
|
||||
axes = torch.randperm(ndim)
|
||||
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
if x.shape[0] == ndim: # permutate vector components
|
||||
@ -187,4 +204,28 @@ def perm(fields, axes, ndim):
|
||||
|
||||
new_fields.append(x)
|
||||
|
||||
return new_fields
|
||||
return new_fields, axes
|
||||
|
||||
|
||||
def add(fields, fac, std):
|
||||
if fac is None:
|
||||
x = fields[0]
|
||||
fac = torch.zeros((x.shape[0],) + (1,) * (x.dim() - 1))
|
||||
fac.normal_(mean=0, std=std)
|
||||
|
||||
for x in fields:
|
||||
x += fac
|
||||
|
||||
return fac
|
||||
|
||||
|
||||
def mul(fields, fac, std):
|
||||
if fac is None:
|
||||
x = fields[0]
|
||||
fac = torch.ones((x.shape[0],) + (1,) * (x.dim() - 1))
|
||||
fac.log_normal_(mean=0, std=std)
|
||||
|
||||
for x in fields:
|
||||
x *= fac
|
||||
|
||||
return fac
|
||||
|
@ -19,6 +19,8 @@ def test(args):
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=False,
|
||||
aug_add=None,
|
||||
aug_mul=None,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
|
@ -63,6 +63,8 @@ def gpu_worker(local_rank, node, args):
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=args.augment,
|
||||
aug_add=args.aug_add,
|
||||
aug_mul=args.aug_mul,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
@ -92,6 +94,8 @@ def gpu_worker(local_rank, node, args):
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=False,
|
||||
aug_add=None,
|
||||
aug_mul=None,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
|
Loading…
Reference in New Issue
Block a user