Fix cropping bug when scale_factor>1
This commit is contained in:
parent
670364e54c
commit
6c67eaa788
@ -117,6 +117,10 @@ class FieldDataset(Dataset):
|
||||
|
||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||
'only support integer upsampling'
|
||||
if scale_factor > 1:
|
||||
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||
if any(self.size * scale_factor != tgt_size):
|
||||
raise ValueError('input size x scale factor != target size')
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
self.nsample = self.nfile * self.ncrop
|
||||
@ -134,12 +138,13 @@ class FieldDataset(Dataset):
|
||||
|
||||
for d, shift in enumerate(self.aug_shift):
|
||||
if shift is not None:
|
||||
anchor[d] += torch.randint(shift, (1,))
|
||||
anchor[d] += torch.randint(int(shift), (1,))
|
||||
|
||||
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop * self.scale_factor,
|
||||
np.zeros_like(self.pad), self.size)
|
||||
np.zeros_like(self.pad),
|
||||
self.size * self.scale_factor)
|
||||
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||
|
Loading…
Reference in New Issue
Block a user