Add unequal/asymmetric cropping, padding, and aug_shift

This commit is contained in:
Yin Li 2021-03-12 15:25:02 -05:00
parent 183a223ee6
commit fd1cdb0ce7
2 changed files with 41 additions and 30 deletions

View file

@ -40,22 +40,27 @@ def add_common_args(parser):
'of input normalization functions')
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
'of target normalization functions')
parser.add_argument('--crop', type=int,
parser.add_argument('--crop', type=int_tuple,
help='size to crop the input and target data. Default is the '
'field size')
parser.add_argument('--crop-start', type=int,
help='starting point of the first crop. Default is the origin')
parser.add_argument('--crop-stop', type=int,
'field size. Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-start', type=int_tuple,
help='starting point of the first crop. Default is the origin. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-stop', type=int_tuple,
help='stopping point of the last crop. Default is the opposite '
'corner to the origin')
parser.add_argument('--crop-step', type=int,
help='spacing between crops. Default is the crop size')
parser.add_argument('--in-pad', '--pad', default=0, type=int,
'corner to the origin. Comma-sep. list of 1 or d integers')
parser.add_argument('--crop-step', type=int_tuple,
help='spacing between crops. Default is the crop size. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
parser.add_argument('--tgt-pad', default=0, type=int,
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
'integers, to pad equally along all axes, symmetrically on each, '
'or by the specified size on every boundary, respectively')
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
help='size to pad the target data beyond the crop size, assuming '
'periodic boundary condition, useful for super-resolution')
'periodic boundary condition, useful for super-resolution. '
'Comma-sep. list with the same format as --in-pad')
parser.add_argument('--scale-factor', default=1, type=int,
help='upsampling factor for super-resolution, in which case '
'crop and pad are sizes of the input resolution')
@ -101,10 +106,11 @@ def add_train_args(parser):
help='comma-sep. list of glob patterns for validation target data')
parser.add_argument('--augment', action='store_true',
help='enable data augmentation of axis flipping and permutation')
parser.add_argument('--aug-shift', type=int,
help='data augmentation by shifting [0, aug_shift) pixels, '
parser.add_argument('--aug-shift', type=int_tuple,
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
'useful for models that treat neighboring pixels differently, '
'e.g. with strided convolutions')
'e.g. with strided convolutions. '
'Comma-sep. list of 1 or d integers')
parser.add_argument('--aug-add', type=float,
help='additive data augmentation, (normal) std, '
'same factor for all fields')
@ -164,14 +170,13 @@ def str_list(s):
return s.split(',')
#def int_tuple(t):
# t = t.split(',')
# t = tuple(int(i) for i in t)
# if len(t) == 1:
# t = t[0]
# elif len(t) != 6:
# raise ValueError('size must be int or 6-tuple')
# return t
def int_tuple(s):
    """Parse a comma-separated string of integers (argparse ``type=``).

    Parameters
    ----------
    s : str
        Text such as ``'4'`` or ``'4,8,16'``.  Whitespace around each field
        is tolerated because ``int()`` strips it.

    Returns
    -------
    int or tuple of int
        A bare ``int`` when exactly one field is given, otherwise a tuple
        with one entry per field.  Callers use the scalar form to mean
        "the same value along every axis".

    Raises
    ------
    ValueError
        If any field is not a valid integer literal (including empty
        fields produced by a leading, trailing, or doubled comma).
    """
    values = tuple(int(field) for field in s.split(','))
    # A single value is unwrapped so downstream code can tell "one number for
    # all axes" apart from an explicit per-axis tuple.
    if len(values) == 1:
        return values[0]
    return values
def set_common_args(args):

View file

@ -109,14 +109,20 @@ class FieldDataset(Dataset):
)], axis=-1).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
'only support symmetric padding for now'
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
def format_pad(pad, ndim):
    """Normalize a padding specification to an array of shape ``(ndim, 2)``.

    Each row of the result holds the two padding sizes for one axis.

    Parameters
    ----------
    pad : int, or tuple/list of int
        * a single integer: the same padding on every boundary;
        * ``ndim`` integers: one size per axis, repeated for both
          boundaries of that axis (symmetric padding);
        * ``2 * ndim`` integers: one size per boundary, consumed in pairs
          per axis.
    ndim : int
        Number of axes.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(ndim, 2)``.

    Raises
    ------
    ValueError
        If ``pad`` is none of the accepted forms or its length matches
        neither ``ndim`` nor ``2 * ndim``.
    """
    # NumPy integer scalars and lists are accepted alongside the int / tuple
    # values produced by the command-line parser.
    if isinstance(pad, (int, np.integer)):
        pad = np.broadcast_to(pad, ndim * 2)
    elif isinstance(pad, (tuple, list)) and len(pad) == ndim:
        # (a, b, ...) -> (a, a, b, b, ...): same size on both sides of an axis
        pad = np.repeat(pad, 2)
    elif isinstance(pad, (tuple, list)) and len(pad) == ndim * 2:
        pad = np.array(pad)
    else:
        raise ValueError('pad and ndim mismatch')

    return pad.reshape(ndim, 2)
self.in_pad = format_pad(in_pad, self.ndim)
self.tgt_pad = format_pad(tgt_pad, self.ndim)
assert isinstance(scale_factor, int) and scale_factor >= 1, \
'only support integer upsampling'
if scale_factor > 1:
if scale_factor != 1:
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
if any(self.size * scale_factor != tgt_size):
raise ValueError('input size x scale factor != target size')