Add tgt_pad, rename pad to in_pad
tgt_pad can be useful for scale_factor > 1
This commit is contained in:
parent
39ad59436e
commit
154376d95a
@ -50,9 +50,12 @@ def add_common_args(parser):
|
|||||||
'corner to the origin')
|
'corner to the origin')
|
||||||
parser.add_argument('--crop-step', type=int,
|
parser.add_argument('--crop-step', type=int,
|
||||||
help='spacing between crops. Default is the crop size')
|
help='spacing between crops. Default is the crop size')
|
||||||
parser.add_argument('--pad', default=0, type=int,
|
parser.add_argument('--in-pad', '--pad', default=0, type=int,
|
||||||
help='size to pad the input data beyond the crop size, assuming '
|
help='size to pad the input data beyond the crop size, assuming '
|
||||||
'periodic boundary condition')
|
'periodic boundary condition')
|
||||||
|
parser.add_argument('--tgt-pad', default=0, type=int,
|
||||||
|
help='size to pad the target data beyond the crop size, assuming '
|
||||||
|
'periodic boundary condition, useful for super-resolution')
|
||||||
parser.add_argument('--scale-factor', default=1, type=int,
|
parser.add_argument('--scale-factor', default=1, type=int,
|
||||||
help='upsampling factor for super-resolution, in which case '
|
help='upsampling factor for super-resolution, in which case '
|
||||||
'crop and pad are sizes of the input resolution')
|
'crop and pad are sizes of the input resolution')
|
||||||
|
@ -43,7 +43,7 @@ class FieldDataset(Dataset):
|
|||||||
in_norms=None, tgt_norms=None, callback_at=None,
|
in_norms=None, tgt_norms=None, callback_at=None,
|
||||||
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
augment=False, aug_shift=None, aug_add=None, aug_mul=None,
|
||||||
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
crop=None, crop_start=None, crop_stop=None, crop_step=None,
|
||||||
pad=0, scale_factor=1):
|
in_pad=0, tgt_pad=0, scale_factor=1):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(* in_file_lists))
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ class FieldDataset(Dataset):
|
|||||||
self.tgt_files = list(zip(* tgt_file_lists))
|
self.tgt_files = list(zip(* tgt_file_lists))
|
||||||
|
|
||||||
assert len(self.in_files) == len(self.tgt_files), \
|
assert len(self.in_files) == len(self.tgt_files), \
|
||||||
'number of input and target fields do not match'
|
'number of input and target fields do not match'
|
||||||
self.nfile = len(self.in_files)
|
self.nfile = len(self.in_files)
|
||||||
|
|
||||||
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
|
assert self.nfile > 0, 'file not found for {}'.format(in_patterns)
|
||||||
@ -67,12 +67,12 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
if in_norms is not None:
|
if in_norms is not None:
|
||||||
assert len(in_patterns) == len(in_norms), \
|
assert len(in_patterns) == len(in_norms), \
|
||||||
'numbers of input normalization functions and fields do not match'
|
'numbers of input normalization functions and fields do not match'
|
||||||
self.in_norms = in_norms
|
self.in_norms = in_norms
|
||||||
|
|
||||||
if tgt_norms is not None:
|
if tgt_norms is not None:
|
||||||
assert len(tgt_patterns) == len(tgt_norms), \
|
assert len(tgt_patterns) == len(tgt_norms), \
|
||||||
'numbers of target normalization functions and fields do not match'
|
'numbers of target normalization functions and fields do not match'
|
||||||
self.tgt_norms = tgt_norms
|
self.tgt_norms = tgt_norms
|
||||||
|
|
||||||
self.callback_at = callback_at
|
self.callback_at = callback_at
|
||||||
@ -110,11 +110,13 @@ class FieldDataset(Dataset):
|
|||||||
)], axis=-1).reshape(-1, self.ndim)
|
)], axis=-1).reshape(-1, self.ndim)
|
||||||
self.ncrop = len(self.anchors)
|
self.ncrop = len(self.anchors)
|
||||||
|
|
||||||
assert isinstance(pad, int), 'only support symmetric padding for now'
|
assert isinstance(in_pad, int) and isinstance(tgt_pad, int), \
|
||||||
self.pad = np.broadcast_to(pad, (self.ndim, 2))
|
'only support symmetric padding for now'
|
||||||
|
self.in_pad = np.broadcast_to(in_pad, (self.ndim, 2))
|
||||||
|
self.tgt_pad = np.broadcast_to(tgt_pad, (self.ndim, 2))
|
||||||
|
|
||||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||||
'only support integer upsampling'
|
'only support integer upsampling'
|
||||||
if scale_factor > 1:
|
if scale_factor > 1:
|
||||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||||
if any(self.size * scale_factor != tgt_size):
|
if any(self.size * scale_factor != tgt_size):
|
||||||
@ -138,10 +140,10 @@ class FieldDataset(Dataset):
|
|||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(int(shift), (1,))
|
anchor[d] += torch.randint(int(shift), (1,))
|
||||||
|
|
||||||
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size)
|
||||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor,
|
||||||
np.zeros_like(self.pad),
|
self.tgt_pad,
|
||||||
self.size * self.scale_factor)
|
self.size * self.scale_factor)
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
|
@ -28,7 +28,8 @@ def test(args):
|
|||||||
crop_start=args.crop_start,
|
crop_start=args.crop_start,
|
||||||
crop_stop=args.crop_stop,
|
crop_stop=args.crop_stop,
|
||||||
crop_step=args.crop_step,
|
crop_step=args.crop_step,
|
||||||
pad=args.pad,
|
in_pad=args.in_pad,
|
||||||
|
tgt_pad=args.tgt_pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
)
|
)
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
|
@ -72,7 +72,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
crop_start=args.crop_start,
|
crop_start=args.crop_start,
|
||||||
crop_stop=args.crop_stop,
|
crop_stop=args.crop_stop,
|
||||||
crop_step=args.crop_step,
|
crop_step=args.crop_step,
|
||||||
pad=args.pad,
|
in_pad=args.in_pad,
|
||||||
|
tgt_pad=args.tgt_pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
)
|
)
|
||||||
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
train_sampler = DistFieldSampler(train_dataset, shuffle=True,
|
||||||
@ -102,7 +103,8 @@ def gpu_worker(local_rank, node, args):
|
|||||||
crop_start=args.crop_start,
|
crop_start=args.crop_start,
|
||||||
crop_stop=args.crop_stop,
|
crop_stop=args.crop_stop,
|
||||||
crop_step=args.crop_step,
|
crop_step=args.crop_step,
|
||||||
pad=args.pad,
|
in_pad=args.in_pad,
|
||||||
|
tgt_pad=args.tgt_pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
)
|
)
|
||||||
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
val_sampler = DistFieldSampler(val_dataset, shuffle=False,
|
||||||
|
Loading…
Reference in New Issue
Block a user