Add additive and multiplicative augmentation
This commit is contained in:
parent
67e5ed9eb6
commit
897c3563db
@ -73,7 +73,13 @@ def add_train_args(parser):
|
|||||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||||
help='comma-sep. list of glob patterns for validation target data')
|
help='comma-sep. list of glob patterns for validation target data')
|
||||||
parser.add_argument('--augment', action='store_true',
|
parser.add_argument('--augment', action='store_true',
|
||||||
help='enable training data augmentation')
|
help='enable data augmentation of axis flipping and permutation')
|
||||||
|
parser.add_argument('--aug-add', type=float,
|
||||||
|
help='additive data augmentation, (normal) std, '
|
||||||
|
'same factor for all fields')
|
||||||
|
parser.add_argument('--aug-mul', type=float,
|
||||||
|
help='multiplicative data augmentation, (log-normal) std, '
|
||||||
|
'same factor for all fields')
|
||||||
|
|
||||||
parser.add_argument('--adv-model', type=str,
|
parser.add_argument('--adv-model', type=str,
|
||||||
help='enable adversary model from .models')
|
help='enable adversary model from .models')
|
||||||
|
@ -18,7 +18,10 @@ class FieldDataset(Dataset):
|
|||||||
`in_norms` is a list of of functions to normalize the input fields.
|
`in_norms` is a list of of functions to normalize the input fields.
|
||||||
Likewise for `tgt_norms`.
|
Likewise for `tgt_norms`.
|
||||||
|
|
||||||
Data augmentations are supported for scalar and vector fields.
|
Scalar and vector fields can be augmented by flipping and permutating the axes.
|
||||||
|
In 3D these form the full octahedral symmetry known as the Oh point group.
|
||||||
|
Additive and multiplicative augmentation are also possible, but with all fields
|
||||||
|
added or multiplied by the same factor.
|
||||||
|
|
||||||
Input and target fields can be cropped.
|
Input and target fields can be cropped.
|
||||||
Input fields can be padded assuming periodic boundary condition.
|
Input fields can be padded assuming periodic boundary condition.
|
||||||
@ -32,8 +35,10 @@ class FieldDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns,
|
def __init__(self, in_patterns, tgt_patterns,
|
||||||
in_norms=None, tgt_norms=None,
|
in_norms=None, tgt_norms=None,
|
||||||
augment=False, crop=None, pad=0, scale_factor=1,
|
augment=False, aug_add=None, aug_mul=None,
|
||||||
cache=False, div_data=False, rank=None, world_size=None):
|
crop=None, pad=0, scale_factor=1,
|
||||||
|
cache=False, div_data=False,
|
||||||
|
rank=None, world_size=None):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(* in_file_lists))
|
||||||
|
|
||||||
@ -72,6 +77,8 @@ class FieldDataset(Dataset):
|
|||||||
self.augment = augment
|
self.augment = augment
|
||||||
if self.ndim == 1 and self.augment:
|
if self.ndim == 1 and self.augment:
|
||||||
raise ValueError('cannot augment 1D fields')
|
raise ValueError('cannot augment 1D fields')
|
||||||
|
self.aug_add = aug_add
|
||||||
|
self.aug_mul = aug_mul
|
||||||
|
|
||||||
if crop is None:
|
if crop is None:
|
||||||
self.crop = self.size
|
self.crop = self.size
|
||||||
@ -121,18 +128,6 @@ class FieldDataset(Dataset):
|
|||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
|
|
||||||
if self.augment:
|
|
||||||
flip_axes = torch.randint(2, (self.ndim,), dtype=torch.bool)
|
|
||||||
flip_axes = torch.arange(self.ndim)[flip_axes]
|
|
||||||
|
|
||||||
in_fields = flip(in_fields, flip_axes, self.ndim)
|
|
||||||
tgt_fields = flip(tgt_fields, flip_axes, self.ndim)
|
|
||||||
|
|
||||||
perm_axes = torch.randperm(self.ndim)
|
|
||||||
|
|
||||||
in_fields = perm(in_fields, perm_axes, self.ndim)
|
|
||||||
tgt_fields = perm(tgt_fields, perm_axes, self.ndim)
|
|
||||||
|
|
||||||
if self.in_norms is not None:
|
if self.in_norms is not None:
|
||||||
for norm, x in zip(self.in_norms, in_fields):
|
for norm, x in zip(self.in_norms, in_fields):
|
||||||
norm(x)
|
norm(x)
|
||||||
@ -140,6 +135,21 @@ class FieldDataset(Dataset):
|
|||||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||||
norm(x)
|
norm(x)
|
||||||
|
|
||||||
|
if self.augment:
|
||||||
|
in_fields, flip_axes = flip(in_fields, None, self.ndim)
|
||||||
|
tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
||||||
|
|
||||||
|
in_fields, perm_axes = perm(in_fields, None, self.ndim)
|
||||||
|
tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
||||||
|
|
||||||
|
if self.aug_add is not None:
|
||||||
|
add_fac = add(in_fields, None, self.aug_add)
|
||||||
|
add_fac = add(tgt_fields, add_fac, self.aug_add)
|
||||||
|
|
||||||
|
if self.aug_mul is not None:
|
||||||
|
mul_fac = mul(in_fields, None, self.aug_mul)
|
||||||
|
mul_fac = mul(tgt_fields, mul_fac, self.aug_mul)
|
||||||
|
|
||||||
in_fields = torch.cat(in_fields, dim=0)
|
in_fields = torch.cat(in_fields, dim=0)
|
||||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||||
|
|
||||||
@ -159,7 +169,11 @@ def crop(fields, start, crop, pad):
|
|||||||
|
|
||||||
|
|
||||||
def flip(fields, axes, ndim):
|
def flip(fields, axes, ndim):
|
||||||
assert ndim > 1, 'flipping is ambiguous for 1D vectors'
|
assert ndim > 1, 'flipping is ambiguous for 1D scalars/vectors'
|
||||||
|
|
||||||
|
if axes is None:
|
||||||
|
axes = torch.randint(2, (ndim,), dtype=torch.bool)
|
||||||
|
axes = torch.arange(ndim)[axes]
|
||||||
|
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for x in fields:
|
for x in fields:
|
||||||
@ -171,12 +185,15 @@ def flip(fields, axes, ndim):
|
|||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
return new_fields
|
return new_fields, axes
|
||||||
|
|
||||||
|
|
||||||
def perm(fields, axes, ndim):
|
def perm(fields, axes, ndim):
|
||||||
assert ndim > 1, 'permutation is not necessary for 1D fields'
|
assert ndim > 1, 'permutation is not necessary for 1D fields'
|
||||||
|
|
||||||
|
if axes is None:
|
||||||
|
axes = torch.randperm(ndim)
|
||||||
|
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for x in fields:
|
for x in fields:
|
||||||
if x.shape[0] == ndim: # permutate vector components
|
if x.shape[0] == ndim: # permutate vector components
|
||||||
@ -187,4 +204,28 @@ def perm(fields, axes, ndim):
|
|||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
return new_fields
|
return new_fields, axes
|
||||||
|
|
||||||
|
|
||||||
|
def add(fields, fac, std):
|
||||||
|
if fac is None:
|
||||||
|
x = fields[0]
|
||||||
|
fac = torch.zeros((x.shape[0],) + (1,) * (x.dim() - 1))
|
||||||
|
fac.normal_(mean=0, std=std)
|
||||||
|
|
||||||
|
for x in fields:
|
||||||
|
x += fac
|
||||||
|
|
||||||
|
return fac
|
||||||
|
|
||||||
|
|
||||||
|
def mul(fields, fac, std):
|
||||||
|
if fac is None:
|
||||||
|
x = fields[0]
|
||||||
|
fac = torch.ones((x.shape[0],) + (1,) * (x.dim() - 1))
|
||||||
|
fac.log_normal_(mean=0, std=std)
|
||||||
|
|
||||||
|
for x in fields:
|
||||||
|
x *= fac
|
||||||
|
|
||||||
|
return fac
|
||||||
|
@ -19,6 +19,8 @@ def test(args):
|
|||||||
in_norms=args.in_norms,
|
in_norms=args.in_norms,
|
||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
augment=False,
|
augment=False,
|
||||||
|
aug_add=None,
|
||||||
|
aug_mul=None,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
|
@ -63,6 +63,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
in_norms=args.in_norms,
|
in_norms=args.in_norms,
|
||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
augment=args.augment,
|
augment=args.augment,
|
||||||
|
aug_add=args.aug_add,
|
||||||
|
aug_mul=args.aug_mul,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
@ -92,6 +94,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
in_norms=args.in_norms,
|
in_norms=args.in_norms,
|
||||||
tgt_norms=args.tgt_norms,
|
tgt_norms=args.tgt_norms,
|
||||||
augment=False,
|
augment=False,
|
||||||
|
aug_add=None,
|
||||||
|
aug_mul=None,
|
||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
|
Loading…
Reference in New Issue
Block a user