From de71df51f5d2fbedb0259037ac86c8c16e2b9554 Mon Sep 17 00:00:00 2001 From: Yin Li Date: Thu, 23 Jan 2020 12:34:10 -0500 Subject: [PATCH] Fix touch-up... --- map2map/data/fields.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 9db64f7..32d13b2 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -149,23 +149,22 @@ def crop(fields, start, crop, pad, scale_factor=1): new_fields = [] for x in fields: for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): - begin, stop = i - p0, i + N + p1 + begin, end = i - p0, i + N + p1 if scale_factor > 1: # add buffer for linear interpolation - begin, stop = begin - 1, stop + 1 + begin, end = begin - 1, end + 1 - x = x.take(range(begin, stop), axis=1 + d, mode='wrap') + x = x.take(range(begin, end), axis=1 + d, mode='wrap') if scale_factor > 1: - x = np.expand_dims(x,axis=0) - x = torch.from_numpy(x) + x = torch.from_numpy(x).unsqueeze(0) x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') - x = x[0].numpy() + x = x.numpy().squeeze(0) # remove buffer for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): - begin, stop = scale_factor, N + p0 + p1 - scale_factor - x = x.take(range(begin, stop), axis=1 + d) + begin, end = scale_factor, N + p0 + p1 - scale_factor + x = x.take(range(begin, end), axis=1 + d) new_fields.append(x)