Add tgt_pad, rename pad to in_pad

tgt_pad can be useful for scale_factor > 1
This commit is contained in:
Yin Li 2020-09-12 13:35:02 -04:00
parent 39ad59436e
commit 154376d95a
4 changed files with 21 additions and 13 deletions

View File

@ -50,9 +50,12 @@ def add_common_args(parser):
'corner to the origin')
parser.add_argument('--crop-step', type=int,
help='spacing between crops. Default is the crop size')
parser.add_argument('--pad', default=0, type=int,
parser.add_argument('--in-pad', '--pad', default=0, type=int,
help='size to pad the input data beyond the crop size, assuming '
'periodic boundary condition')
parser.add_argument('--tgt-pad', default=0, type=int,
help='size to pad the target data beyond the crop size, assuming '
'periodic boundary condition, useful for super-resolution')
parser.add_argument('--scale-factor', default=1, type=int,
help='upsampling factor for super-resolution, in which case '
'crop and pad are sizes of the input resolution')

View File

@ -43,7 +43,7 @@ class FieldDataset(Dataset):
in_norms=None, tgt_norms=None, callback_at=None,
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
crop=None, crop_start=None, crop_stop=None, crop_step=None,
pad=0, scale_factor=1):
in_pad=0, tgt_pad=0, scale_factor=1):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))
@ -110,8 +110,10 @@ class FieldDataset(Dataset):
)], axis=-1).reshape(-1, self.ndim)
self.ncrop = len(self.anchors)
assert isinstance(pad, int), 'only support symmetric padding for now'
self.pad = np.broadcast_to(pad, (self.ndim, 2))
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
'only support symmetric padding for now'
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
assert isinstance(scale_factor, int) and scale_factor >= 1, \
'only support integer upsampling'
@ -138,10 +140,10 @@ class FieldDataset(Dataset):
if shift is not None:
anchor[d] += torch.randint(int(shift), (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor,
np.zeros_like(self.pad),
self.tgt_pad,
self.size * self.scale_factor)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]

View File

@ -28,7 +28,8 @@ def test(args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
test_loader = DataLoader(

View File

@ -72,7 +72,8 @@ def gpu_worker(local_rank, node, args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
@ -102,7 +103,8 @@ def gpu_worker(local_rank, node, args):
crop_start=args.crop_start,
crop_stop=args.crop_stop,
crop_step=args.crop_step,
pad=args.pad,
in_pad=args.in_pad,
tgt_pad=args.tgt_pad,
scale_factor=args.scale_factor,
)
val_sampler = DistFieldSampler(val_dataset, shuffle=False,