Change linear interpolation to nearest in super-resolution

This commit is contained in:
Yin Li 2020-02-05 21:09:39 -05:00
parent 679f9f2545
commit e20f3194a5

View File

@ -162,20 +162,17 @@ def crop(fields, start, crop, pad, scale_factor=1):
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
begin, end = i - p0, i + N + p1 begin, end = i - p0, i + N + p1
if scale_factor > 1: # add buffer for linear interpolation
begin, end = begin - 1, end + 1
x = x.take(range(begin, end), axis=1 + d, mode='wrap') x = x.take(range(begin, end), axis=1 + d, mode='wrap')
if scale_factor > 1: if scale_factor > 1:
x = torch.from_numpy(x).unsqueeze(0) x = torch.from_numpy(x).unsqueeze(0)
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') x = F.interpolate(x, scale_factor=scale_factor, mode='nearest')
x = x.numpy().squeeze(0) x = x.numpy().squeeze(0)
# remove buffer and excess padding # remove excess padding
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
begin = scale_factor + (scale_factor - 1) * p0 begin = (scale_factor - 1) * p0
end = scale_factor * (N + p0 + 1) + p1 end = scale_factor * (N + p0) + p1
x = x.take(range(begin, end), axis=1 + d) x = x.take(range(begin, end), axis=1 + d)
new_fields.append(x) new_fields.append(x)