Add tgt_pad, rename pad to in_pad

tgt_pad can be useful for scale_factor > 1
Yin Li 2020-09-12 13:35:02 -04:00
parent 39ad59436e
commit 154376d95a
4 changed files with 21 additions and 13 deletions
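In terms of shapes the change is simple bookkeeping: the input crop is padded by in_pad on each side at the input resolution, while the target crop lives at scale_factor times that resolution and now gets its own tgt_pad (before this commit the target crop was always unpadded). Padding the target presumably matters when a super-resolution model produces output beyond crop * scale_factor, which would otherwise have no target voxels to compare against. A minimal sketch of the resulting per-crop shapes, not part of the commit and with made-up numbers:

import numpy as np

# Illustrative shape bookkeeping for one crop; periodic boundaries mean the
# padding can always be filled by wrapping around the box.
ndim = 3
crop = np.full(ndim, 128)                  # crop size at input resolution
in_pad = np.broadcast_to(3, (ndim, 2))     # (before, after) pad of the input
tgt_pad = np.broadcast_to(6, (ndim, 2))    # (before, after) pad of the target
scale_factor = 2

in_shape = crop + in_pad.sum(axis=1)                   # 134 per dimension
tgt_shape = crop * scale_factor + tgt_pad.sum(axis=1)  # 268 per dimension
print(in_shape, tgt_shape)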

View File

@@ -50,9 +50,12 @@ def add_common_args(parser):
'corner to the origin')
parser.add_argument('--crop-step', type=int,
help='spacing between crops. Default is the crop size')
-parser.add_argument('--pad', default=0, type=int,
+parser.add_argument('--in-pad', '--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
+parser.add_argument('--tgt-pad', default=0, type=int,
+help='size to pad the target data beyond the crop size, assuming '
+'periodic boundary condition, useful for super-resolution')
parser.add_argument('--scale-factor', default=1, type=int,
help='upsampling factor for super-resolution, in which case '
'crop and pad are sizes of the input resolution')
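One detail in the hunk above: the old spelling is kept as an alias by passing both option strings to add_argument, and argparse derives the destination from the first long option, so the value always lands in args.in_pad. A standalone sketch, not part of the commit, with made-up values:

import argparse

# '--pad' still works on the command line, but the parsed value is stored
# under the first long option's name, i.e. args.in_pad.
parser = argparse.ArgumentParser()
parser.add_argument('--in-pad', '--pad', default=0, type=int)
parser.add_argument('--tgt-pad', default=0, type=int)

args = parser.parse_args(['--pad', '3', '--tgt-pad', '6'])
print(args.in_pad, args.tgt_pad)  # 3 6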

View File

@@ -43,7 +43,7 @@ class FieldDataset(Dataset):
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
-pad=0, scale_factor=1):
+in_pad=0, tgt_pad=0, scale_factor=1):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@@ -51,7 +51,7 @@ class FieldDataset(Dataset):
self.tgt_files = list(zip(* tgt_file_lists))
assert len(self.in_files) == len(self.tgt_files), \
'number of input and target fields do not match'
self.nfile = len(self.in_files)
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
@@ -67,12 +67,12 @@ class FieldDataset(Dataset):
if in_norms is not None:
assert len(in_patterns) == len(in_norms), \
'numbers of input normalization functions and fields do not match'
self.in_norms = in_norms
if tgt_norms is not None:
assert len(tgt_patterns) == len(tgt_norms), \
'numbers of target normalization functions and fields do not match'
self.tgt_norms = tgt_norms
self.callback_at = callback_at
@@ -110,11 +110,13 @@ class FieldDataset(Dataset):
)], axis=-1).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
-assert isinstance(pad, int), 'only support symmetric padding for now'
-self.pad = np.broadcast_to(pad, (self.ndim, 2))
+assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
+'only support symmetric padding for now'
+self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
+self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
assert isinstance(scale_factor, int) and scale_factor >= 1, \
'only support integer upsampling'
if scale_factor > 1:
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
if any(self.size * scale_factor != tgt_size):
@@ -138,10 +140,10 @@ class FieldDataset(Dataset):
if shift is not None:
anchor[d] += torch.randint(int(shift), (1,))
-in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
+in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor,
-np.zeros_like(self.pad),
+self.tgt_pad,
self.size * self.scale_factor)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
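The two crop() calls above carry the actual behaviour change: the input crop keeps its in_pad, while the target crop now receives tgt_pad instead of the hard-coded zero padding. The crop() helper itself is not part of this diff; the sketch below only illustrates what a symmetric crop with periodic-boundary padding and this argument order could look like, and is not the project's implementation:

import numpy as np

def crop_periodic(fields, anchor, crop, pad, size):
    # Hypothetical stand-in for crop(): cut a crop-sized box starting at
    # anchor, extend it by pad[d] = (before, after) along each dimension,
    # and wrap the indices around the periodic box of shape size.
    out = []
    for x in fields:                                  # x: (channels, *size)
        for d, (a, c, (p0, p1), n) in enumerate(zip(anchor, crop, pad, size)):
            idx = np.arange(a - p0, a + c + p1) % n   # periodic wrapping
            x = x.take(idx, axis=1 + d)               # axis 0 is the channel axis
        out.append(x)
    return out

# Example: a 2-channel field on an 8^3 periodic box, cropped to 4^3 with
# 1 voxel of padding per side, giving a 6^3 output.
field = np.arange(2 * 8**3).reshape(2, 8, 8, 8)
out, = crop_periodic([field],
                     anchor=np.array([6, 0, 3]),
                     crop=np.array([4, 4, 4]),
                     pad=np.broadcast_to(1, (3, 2)),
                     size=np.array([8, 8, 8]))
print(out.shape)  # (2, 6, 6, 6)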

View File

@@ -28,7 +28,8 @@ def test(args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
-pad=args.pad,
+in_pad=args.in_pad,
+tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
test_loader = DataLoader(

View File

@@ -72,7 +72,8 @@ def gpu_worker(local_rank, node, args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
-pad=args.pad,
+in_pad=args.in_pad,
+tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
@@ -102,7 +103,8 @@ def gpu_worker(local_rank, node, args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
-pad=args.pad,
+in_pad=args.in_pad,
+tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
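Since --tgt-pad defaults to 0, the train, validation, and test datasets above behave exactly as before unless the new flag is set: broadcasting the default reproduces the zero padding that the old code passed to crop() explicitly. A quick check of that equivalence, not part of the commit:

import numpy as np

ndim = 3
old_pad = np.broadcast_to(3, (ndim, 2))   # what self.pad used to hold
old_tgt = np.zeros_like(old_pad)          # old hard-coded target padding
new_tgt = np.broadcast_to(0, (ndim, 2))   # new self.tgt_pad with the default 0
print(np.array_equal(old_tgt, new_tgt))   # True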