From e20f3194a574a6f7946fbb70dc8338e8d2d10f07 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Wed, 5 Feb 2020 21:09:39 -0500 Subject: [PATCH] Change linear interpolation to nearest in super-resolution --- map2map/data/fields.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 10e05b6..ec4055d 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -162,20 +162,17 @@ def crop(fields, start, crop, pad, scale_factor=1): for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): begin, end = i - p0, i + N + p1 - if scale_factor > 1: # add buffer for linear interpolation - begin, end = begin - 1, end + 1 - x = x.take(range(begin, end), axis=1 + d, mode='wrap') if scale_factor > 1: x = torch.from_numpy(x).unsqueeze(0) - x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') + x = F.interpolate(x, scale_factor=scale_factor, mode='nearest') x = x.numpy().squeeze(0) - # remove buffer and excess padding + # remove excess padding for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): - begin = scale_factor + (scale_factor - 1) * p0 - end = scale_factor * (N + p0 + 1) + p1 + begin = (scale_factor - 1) * p0 + end = scale_factor * (N + p0) + p1 x = x.take(range(begin, end), axis=1 + d) new_fields.append(x)