fix crop function with upsample

This commit is contained in:
yueyingn 2020-01-23 11:34:20 -05:00 committed by GitHub
parent 581ad5d60b
commit d5b76b3725
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -149,22 +149,23 @@ def crop(fields, start, crop, pad, scale_factor=1):
new_fields = [] new_fields = []
for x in fields: for x in fields:
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)): for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
start, stop = i - p0, i + N + p1 begin, stop = i - p0, i + N + p1
if scale_factor > 1: # add buffer for linear interpolation if scale_factor > 1: # add buffer for linear interpolation
start, stop = start - 1, stop + 1 begin, stop = begin - 1, stop + 1
x = x.take(range(start, stop), axis=1 + d, mode='wrap') x = x.take(range(begin, stop), axis=1 + d, mode='wrap')
if scale_factor > 1: if scale_factor > 1:
x = np.expand_dims(x,axis=0)
x = torch.from_numpy(x) x = torch.from_numpy(x)
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
x = x.numpy() x = x[0].numpy()
# remove buffer # remove buffer
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
start, stop = scale_factor, N + p0 + p1 - scale_factor begin, stop = scale_factor, N + p0 + p1 - scale_factor
x = x.take(range(start, stop), axis=1 + d) x = x.take(range(begin, stop), axis=1 + d)
new_fields.append(x) new_fields.append(x)