Add unequal/asymmetric cropping, padding, and aug_shift
This commit is contained in:
parent
183a223ee6
commit
fd1cdb0ce7
@ -40,22 +40,27 @@ def add_common_args(parser):
|
||||
'of input normalization functions')
|
||||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
||||
'of target normalization functions')
|
||||
parser.add_argument('--crop', type=int,
|
||||
parser.add_argument('--crop', type=int_tuple,
|
||||
help='size to crop the input and target data. Default is the '
|
||||
'field size')
|
||||
parser.add_argument('--crop-start', type=int,
|
||||
help='starting point of the first crop. Default is the origin')
|
||||
parser.add_argument('--crop-stop', type=int,
|
||||
'field size. Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-start', type=int_tuple,
|
||||
help='starting point of the first crop. Default is the origin. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-stop', type=int_tuple,
|
||||
help='stopping point of the last crop. Default is the opposite '
|
||||
'corner to the origin')
|
||||
parser.add_argument('--crop-step', type=int,
|
||||
help='spacing between crops. Default is the crop size')
|
||||
parser.add_argument('--in-pad', '--pad', default=0, type=int,
|
||||
'corner to the origin. Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--crop-step', type=int_tuple,
|
||||
help='spacing between crops. Default is the crop size. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition')
|
||||
parser.add_argument('--tgt-pad', default=0, type=int,
|
||||
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
|
||||
'integers, to pad equally along all axes, symmetrically on each, '
|
||||
'or by the specified size on every boundary, respectively')
|
||||
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
|
||||
help='size to pad the target data beyond the crop size, assuming '
|
||||
'periodic boundary condition, useful for super-resolution')
|
||||
'periodic boundary condition, useful for super-resolution. '
|
||||
'Comma-sep. list with the same format as --in-pad')
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='upsampling factor for super-resolution, in which case '
|
||||
'crop and pad are sizes of the input resolution')
|
||||
@ -101,10 +106,11 @@ def add_train_args(parser):
|
||||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable data augmentation of axis flipping and permutation')
|
||||
parser.add_argument('--aug-shift', type=int,
|
||||
help='data augmentation by shifting [0, aug_shift) pixels, '
|
||||
parser.add_argument('--aug-shift', type=int_tuple,
|
||||
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
|
||||
'useful for models that treat neighboring pixels differently, '
|
||||
'e.g. with strided convolutions')
|
||||
'e.g. with strided convolutions. '
|
||||
'Comma-sep. list of 1 or d integers')
|
||||
parser.add_argument('--aug-add', type=float,
|
||||
help='additive data augmentation, (normal) std, '
|
||||
'same factor for all fields')
|
||||
@ -164,14 +170,13 @@ def str_list(s):
|
||||
return s.split(',')
|
||||
|
||||
|
||||
#def int_tuple(t):
|
||||
# t = t.split(',')
|
||||
# t = tuple(int(i) for i in t)
|
||||
# if len(t) == 1:
|
||||
# t = t[0]
|
||||
# elif len(t) != 6:
|
||||
# raise ValueError('size must be int or 6-tuple')
|
||||
# return t
|
||||
def int_tuple(s):
|
||||
t = s.split(',')
|
||||
t = tuple(int(i) for i in t)
|
||||
if len(t) == 1:
|
||||
return t[0]
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def set_common_args(args):
|
||||
|
@ -109,14 +109,20 @@ class FieldDataset(Dataset):
|
||||
)], axis=-1).reshape(-1, self.ndim)
|
||||
self.ncrop = len(self.anchors)
|
||||
|
||||
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
|
||||
'only support symmetric padding for now'
|
||||
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
|
||||
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
|
||||
def format_pad(pad, ndim):
|
||||
if isinstance(pad, int):
|
||||
pad = np.broadcast_to(pad, ndim * 2)
|
||||
elif isinstance(pad, tuple) and len(pad) == ndim:
|
||||
pad = np.repeat(pad, 2)
|
||||
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
|
||||
pad = np.array(pad)
|
||||
else:
|
||||
raise ValueError('pad and ndim mismatch')
|
||||
return pad.reshape(ndim, 2)
|
||||
self.in_pad = format_pad(in_pad, self.ndim)
|
||||
self.tgt_pad = format_pad(tgt_pad, self.ndim)
|
||||
|
||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||
'only support integer upsampling'
|
||||
if scale_factor > 1:
|
||||
if scale_factor != 1:
|
||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||
if any(self.size * scale_factor != tgt_size):
|
||||
raise ValueError('input size x scale factor != target size')
|
||||
|
Loading…
Reference in New Issue
Block a user