Fix cropping bug when scale_factor>1
This commit is contained in:
parent
670364e54c
commit
6c67eaa788
@ -117,6 +117,10 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
assert isinstance(scale_factor, int) and scale_factor >= 1, \
|
||||||
'only support integer upsampling'
|
'only support integer upsampling'
|
||||||
|
if scale_factor > 1:
|
||||||
|
tgt_size = np.load(self.tgt_files[0][0], mmap_mode='r').shape[1:]
|
||||||
|
if any(self.size * scale_factor != tgt_size):
|
||||||
|
raise ValueError('input size x scale factor != target size')
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
self.nsample = self.nfile * self.ncrop
|
self.nsample = self.nfile * self.ncrop
|
||||||
@ -134,12 +138,13 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
for d, shift in enumerate(self.aug_shift):
|
for d, shift in enumerate(self.aug_shift):
|
||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(shift, (1,))
|
anchor[d] += torch.randint(int(shift), (1,))
|
||||||
|
|
||||||
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
||||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor,
|
self.crop * self.scale_factor,
|
||||||
np.zeros_like(self.pad), self.size)
|
np.zeros_like(self.pad),
|
||||||
|
self.size * self.scale_factor)
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
|
Loading…
Reference in New Issue
Block a user