Fix cropping bug when scale_factor>1

This commit is contained in:
Yin Li 2020-08-23 14:46:57 -05:00
parent 670364e54c
commit 6c67eaa788

View file

@ -117,6 +117,10 @@ class FieldDataset(Dataset):
assert isinstance(scale_factor, int) and scale_factor >= 1, \
'only support integer upsampling'
if scale_factor > 1:
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
if any(self.size * scale_factor != tgt_size):
raise ValueError('input size x scale factor != target size')
self.scale_factor = scale_factor
self.nsample = self.nfile * self.ncrop
@ -134,12 +138,13 @@ class FieldDataset(Dataset):
for d, shift in enumerate(self.aug_shift):
if shift is not None:
anchor[d] += torch.randint(shift, (1,))
anchor[d] += torch.randint(int(shift), (1,))
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor,
np.zeros_like(self.pad), self.size)
np.zeros_like(self.pad),
self.size * self.scale_factor)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]