Fix bug that F.interpolate does not apply to np arrays

This commit is contained in:
Yin Li 2020-01-23 07:22:34 -05:00
parent 84a369d4ed
commit 6938eea089

View File

@ -149,13 +149,17 @@ def crop(fields, start, crop, pad, scale_factor=1):
for x in fields: for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
start, stop = i - p0, i + N + p1 start, stop = i - p0, i + N + p1
# add buffer for linear interpolation
if scale_factor > 1: if scale_factor > 1: # add buffer for linear interpolation
start, stop = start - 1, stop + 1 start, stop = start - 1, stop + 1
x = x.take(range(start, stop), axis=1 + d, mode='wrap') x = x.take(range(start, stop), axis=1 + d, mode='wrap')
if scale_factor > 1: if scale_factor > 1:
x = torch.from_numpy(x)
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
x = x.numpy()
# remove buffer # remove buffer
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
start, stop = scale_factor, N + p0 + p1 - scale_factor start, stop = scale_factor, N + p0 + p1 - scale_factor