Add tgt_pad, rename pad to in_pad
tgt_pad can be useful for scale_factor > 1
This commit is contained in:
parent
39ad59436e
commit
154376d95a
4 changed files with 21 additions and 13 deletions
|
@ -50,9 +50,12 @@ def add_common_args(parser):
|
|||
'corner to the origin')
|
||||
parser.add_argument('--crop-step', type=int,
|
||||
help='spacing between crops. Default is the crop size')
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
parser.add_argument('--in-pad', '--pad', default=0, type=int,
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition')
|
||||
parser.add_argument('--tgt-pad', default=0, type=int,
|
||||
help='size to pad the target data beyond the crop size, assuming '
|
||||
'periodic boundary condition, useful for super-resolution')
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='upsampling factor for super-resolution, in which case '
|
||||
'crop and pad are sizes of the input resolution')
|
||||
|
|
|
@ -43,7 +43,7 @@ class FieldDataset(Dataset):
|
|||
in_norms=None, tgt_norms=None, callback_at=None,
|
||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||
pad=0, scale_factor=1):
|
||||
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
|
@ -51,7 +51,7 @@ class FieldDataset(Dataset):
|
|||
self.tgt_files = list(zip(* tgt_file_lists))
|
||||
|
||||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'number of input and target fields do not match'
|
||||
'number of input and target fields do not match'
|
||||
self.nfile = len(self.in_files)
|
||||
|
||||
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
|
||||
|
@ -67,12 +67,12 @@ class FieldDataset(Dataset):
|
|||
|
||||
if in_norms is not None:
|
||||
assert len(in_patterns) == len(in_norms), \
|
||||
'numbers of input normalization functions and fields do not match'
|
||||
'numbers of input normalization functions and fields do not match'
|
||||
self.in_norms = in_norms
|
||||
|
||||
if tgt_norms is not None:
|
||||
assert len(tgt_patterns) == len(tgt_norms), \
|
||||
'numbers of target normalization functions and fields do not match'
|
||||
'numbers of target normalization functions and fields do not match'
|
||||
self.tgt_norms = tgt_norms
|
||||
|
||||
self.callback_at = callback_at
|
||||
|
@ -110,11 +110,13 @@ class FieldDataset(Dataset):
|
|||
)], axis=-1).reshape(-1, self.ndim)
|
||||
self.ncrop = len(self.anchors)
|
||||
|
||||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
||||
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
|
||||
'only support symmetric padding for now'
|
||||
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
|
||||
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
|
||||
|
||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||
'only support integer upsampling'
|
||||
'only support integer upsampling'
|
||||
if scale_factor > 1:
|
||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||
if any(self.size * scale_factor != tgt_size):
|
||||
|
@ -138,10 +140,10 @@ class FieldDataset(Dataset):
|
|||
if shift is not None:
|
||||
anchor[d] += torch.randint(int(shift), (1,))
|
||||
|
||||
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
||||
in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop * self.scale_factor,
|
||||
np.zeros_like(self.pad),
|
||||
self.tgt_pad,
|
||||
self.size * self.scale_factor)
|
||||
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
|
|
|
@ -28,7 +28,8 @@ def test(args):
|
|||
crop_start=args.crop_start,
|
||||
crop_stop=args.crop_stop,
|
||||
crop_step=args.crop_step,
|
||||
pad=args.pad,
|
||||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
|
|
|
@ -72,7 +72,8 @@ def gpu_worker(local_rank, node, args):
|
|||
crop_start=args.crop_start,
|
||||
crop_stop=args.crop_stop,
|
||||
crop_step=args.crop_step,
|
||||
pad=args.pad,
|
||||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
)
|
||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
||||
|
@ -102,7 +103,8 @@ def gpu_worker(local_rank, node, args):
|
|||
crop_start=args.crop_start,
|
||||
crop_stop=args.crop_stop,
|
||||
crop_step=args.crop_step,
|
||||
pad=args.pad,
|
||||
in_pad=args.in_pad,
|
||||
tgt_pad=args.tgt_pad,
|
||||
scale_factor=args.scale_factor,
|
||||
)
|
||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
||||
|
|
Loading…
Reference in a new issue