From fd1cdb0ce73808cb117de6188497426934c9f57f Mon Sep 17 00:00:00 2001 From: Yin Li Date: Fri, 12 Mar 2021 15:25:02 -0500 Subject: [PATCH] Add unequal/asymmetric cropping, padding, and aug_shift --- map2map/args.py | 51 +++++++++++++++++++++++------------------- map2map/data/fields.py | 20 +++++++++++------ 2 files changed, 41 insertions(+), 30 deletions(-) diff --git a/map2map/args.py b/map2map/args.py index ad440c7..f1c1c9b 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -40,22 +40,27 @@ def add_common_args(parser): 'of input normalization functions') parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list ' 'of target normalization functions') - parser.add_argument('--crop', type=int, + parser.add_argument('--crop', type=int_tuple, help='size to crop the input and target data. Default is the ' - 'field size') - parser.add_argument('--crop-start', type=int, - help='starting point of the first crop. Default is the origin') - parser.add_argument('--crop-stop', type=int, + 'field size. Comma-sep. list of 1 or d integers') + parser.add_argument('--crop-start', type=int_tuple, + help='starting point of the first crop. Default is the origin. ' + 'Comma-sep. list of 1 or d integers') + parser.add_argument('--crop-stop', type=int_tuple, help='stopping point of the last crop. Default is the opposite ' - 'corner to the origin') - parser.add_argument('--crop-step', type=int, - help='spacing between crops. Default is the crop size') - parser.add_argument('--in-pad', '--pad', default=0, type=int, + 'corner to the origin. Comma-sep. list of 1 or d integers') + parser.add_argument('--crop-step', type=int_tuple, + help='spacing between crops. Default is the crop size. ' + 'Comma-sep. list of 1 or d integers') + parser.add_argument('--in-pad', '--pad', default=0, type=int_tuple, help='size to pad the input data beyond the crop size, assuming ' - 'periodic boundary condition') - parser.add_argument('--tgt-pad', default=0, type=int, + 'periodic boundary condition. Comma-sep. list of 1, d, or dx2 ' + 'integers, to pad equally along all axes, symmetrically on each, ' + 'or by the specified size on every boundary, respectively') + parser.add_argument('--tgt-pad', default=0, type=int_tuple, help='size to pad the target data beyond the crop size, assuming ' - 'periodic boundary condition, useful for super-resolution') + 'periodic boundary condition, useful for super-resolution. ' + 'Comma-sep. list with the same format as --in-pad') parser.add_argument('--scale-factor', default=1, type=int, help='upsampling factor for super-resolution, in which case ' 'crop and pad are sizes of the input resolution') @@ -101,10 +106,11 @@ def add_train_args(parser): help='comma-sep. list of glob patterns for validation target data') parser.add_argument('--augment', action='store_true', help='enable data augmentation of axis flipping and permutation') - parser.add_argument('--aug-shift', type=int, - help='data augmentation by shifting [0, aug_shift) pixels, ' + parser.add_argument('--aug-shift', type=int_tuple, + help='data augmentation by shifting cropping by [0, aug_shift) pixels, ' 'useful for models that treat neighboring pixels differently, ' - 'e.g. with strided convolutions') + 'e.g. with strided convolutions. ' + 'Comma-sep. list of 1 or d integers') parser.add_argument('--aug-add', type=float, help='additive data augmentation, (normal) std, ' 'same factor for all fields') @@ -164,14 +170,13 @@ def str_list(s): return s.split(',') -#def int_tuple(t): -# t = t.split(',') -# t = tuple(int(i) for i in t) -# if len(t) == 1: -# t = t[0] -# elif len(t) != 6: -# raise ValueError('size must be int or 6-tuple') -# return t +def int_tuple(s): + t = s.split(',') + t = tuple(int(i) for i in t) + if len(t) == 1: + return t[0] + else: + return t def set_common_args(args): diff --git a/map2map/data/fields.py b/map2map/data/fields.py index c610a3a..798a67b 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -109,14 +109,20 @@ class FieldDataset(Dataset): )], axis=-1).reshape(-1, self.ndim) self.ncrop = len(self.anchors) - assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \ - 'only support symmetric padding for now' - self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2)) - self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2)) + def format_pad(pad, ndim): + if isinstance(pad, int): + pad = np.broadcast_to(pad, ndim * 2) + elif isinstance(pad, tuple) and len(pad) == ndim: + pad = np.repeat(pad, 2) + elif isinstance(pad, tuple) and len(pad) == ndim * 2: + pad = np.array(pad) + else: + raise ValueError('pad and ndim mismatch') + return pad.reshape(ndim, 2) + self.in_pad = format_pad(in_pad, self.ndim) + self.tgt_pad = format_pad(tgt_pad, self.ndim) - assert isinstance(scale_factor, int) and scale_factor >= 1, \ - 'only support integer upsampling' - if scale_factor > 1: + if scale_factor != 1: tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:] if any(self.size * scale_factor != tgt_size): raise ValueError('input size x scale factor != target size')