Add unequal/asymmetric cropping, padding, and aug_shift
This commit is contained in:
parent
183a223ee6
commit
fd1cdb0ce7
@ -40,22 +40,27 @@ def add_common_args(parser):
|
|||||||
'of input normalization functions')
|
'of input normalization functions')
|
||||||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
||||||
'of target normalization functions')
|
'of target normalization functions')
|
||||||
parser.add_argument('--crop', type=int,
|
parser.add_argument('--crop', type=int_tuple,
|
||||||
help='size to crop the input and target data. Default is the '
|
help='size to crop the input and target data. Default is the '
|
||||||
'field size')
|
'field size. Comma-sep. list of 1 or d integers')
|
||||||
parser.add_argument('--crop-start', type=int,
|
parser.add_argument('--crop-start', type=int_tuple,
|
||||||
help='starting point of the first crop. Default is the origin')
|
help='starting point of the first crop. Default is the origin. '
|
||||||
parser.add_argument('--crop-stop', type=int,
|
'Comma-sep. list of 1 or d integers')
|
||||||
|
parser.add_argument('--crop-stop', type=int_tuple,
|
||||||
help='stopping point of the last crop. Default is the opposite '
|
help='stopping point of the last crop. Default is the opposite '
|
||||||
'corner to the origin')
|
'corner to the origin. Comma-sep. list of 1 or d integers')
|
||||||
parser.add_argument('--crop-step', type=int,
|
parser.add_argument('--crop-step', type=int_tuple,
|
||||||
help='spacing between crops. Default is the crop size')
|
help='spacing between crops. Default is the crop size. '
|
||||||
parser.add_argument('--in-pad', '--pad', default=0, type=int,
|
'Comma-sep. list of 1 or d integers')
|
||||||
|
parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple,
|
||||||
help='size to pad the input data beyond the crop size, assuming '
|
help='size to pad the input data beyond the crop size, assuming '
|
||||||
'periodic boundary condition')
|
'periodic boundary condition. Comma-sep. list of 1, d, or dx2 '
|
||||||
parser.add_argument('--tgt-pad', default=0, type=int,
|
'integers, to pad equally along all axes, symmetrically on each, '
|
||||||
|
'or by the specified size on every boundary, respectively')
|
||||||
|
parser.add_argument('--tgt-pad', default=0, type=int_tuple,
|
||||||
help='size to pad the target data beyond the crop size, assuming '
|
help='size to pad the target data beyond the crop size, assuming '
|
||||||
'periodic boundary condition, useful for super-resolution')
|
'periodic boundary condition, useful for super-resolution. '
|
||||||
|
'Comma-sep. list with the same format as --in-pad')
|
||||||
parser.add_argument('--scale-factor', default=1, type=int,
|
parser.add_argument('--scale-factor', default=1, type=int,
|
||||||
help='upsampling factor for super-resolution, in which case '
|
help='upsampling factor for super-resolution, in which case '
|
||||||
'crop and pad are sizes of the input resolution')
|
'crop and pad are sizes of the input resolution')
|
||||||
@ -101,10 +106,11 @@ def add_train_args(parser):
|
|||||||
help='comma-sep. list of glob patterns for validation target data')
|
help='comma-sep. list of glob patterns for validation target data')
|
||||||
parser.add_argument('--augment', action='store_true',
|
parser.add_argument('--augment', action='store_true',
|
||||||
help='enable data augmentation of axis flipping and permutation')
|
help='enable data augmentation of axis flipping and permutation')
|
||||||
parser.add_argument('--aug-shift', type=int,
|
parser.add_argument('--aug-shift', type=int_tuple,
|
||||||
help='data augmentation by shifting [0, aug_shift) pixels, '
|
help='data augmentation by shifting cropping by [0, aug_shift) pixels, '
|
||||||
'useful for models that treat neighboring pixels differently, '
|
'useful for models that treat neighboring pixels differently, '
|
||||||
'e.g. with strided convolutions')
|
'e.g. with strided convolutions. '
|
||||||
|
'Comma-sep. list of 1 or d integers')
|
||||||
parser.add_argument('--aug-add', type=float,
|
parser.add_argument('--aug-add', type=float,
|
||||||
help='additive data augmentation, (normal) std, '
|
help='additive data augmentation, (normal) std, '
|
||||||
'same factor for all fields')
|
'same factor for all fields')
|
||||||
@ -164,14 +170,13 @@ def str_list(s):
|
|||||||
return s.split(',')
|
return s.split(',')
|
||||||
|
|
||||||
|
|
||||||
#def int_tuple(t):
|
def int_tuple(s):
|
||||||
# t = t.split(',')
|
t = s.split(',')
|
||||||
# t = tuple(int(i) for i in t)
|
t = tuple(int(i) for i in t)
|
||||||
# if len(t) == 1:
|
if len(t) == 1:
|
||||||
# t = t[0]
|
return t[0]
|
||||||
# elif len(t) != 6:
|
else:
|
||||||
# raise ValueError('size must be int or 6-tuple')
|
return t
|
||||||
# return t
|
|
||||||
|
|
||||||
|
|
||||||
def set_common_args(args):
|
def set_common_args(args):
|
||||||
|
@ -109,14 +109,20 @@ class FieldDataset(Dataset):
|
|||||||
)], axis=-1).reshape(-1, self.ndim)
|
)], axis=-1).reshape(-1, self.ndim)
|
||||||
self.ncrop = len(self.anchors)
|
self.ncrop = len(self.anchors)
|
||||||
|
|
||||||
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
|
def format_pad(pad, ndim):
|
||||||
'only support symmetric padding for now'
|
if isinstance(pad, int):
|
||||||
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
|
pad = np.broadcast_to(pad, ndim * 2)
|
||||||
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
|
elif isinstance(pad, tuple) and len(pad) == ndim:
|
||||||
|
pad = np.repeat(pad, 2)
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == ndim * 2:
|
||||||
|
pad = np.array(pad)
|
||||||
|
else:
|
||||||
|
raise ValueError('pad and ndim mismatch')
|
||||||
|
return pad.reshape(ndim, 2)
|
||||||
|
self.in_pad = format_pad(in_pad, self.ndim)
|
||||||
|
self.tgt_pad = format_pad(tgt_pad, self.ndim)
|
||||||
|
|
||||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
if scale_factor != 1:
|
||||||
'only support integer upsampling'
|
|
||||||
if scale_factor > 1:
|
|
||||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||||
if any(self.size * scale_factor != tgt_size):
|
if any(self.size * scale_factor != tgt_size):
|
||||||
raise ValueError('input size x scale factor != target size')
|
raise ValueError('input size x scale factor != target size')
|
||||||
|
Loading…
Reference in New Issue
Block a user