Fix bug that F.interpolate does not apply to np arrays
This commit is contained in:
parent
84a369d4ed
commit
6938eea089
@ -149,13 +149,17 @@ def crop(fields, start, crop, pad, scale_factor=1):
|
|||||||
for x in fields:
|
for x in fields:
|
||||||
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
for d, (i, N, (p0, p1)) in enumerate(zip(start, crop, pad)):
|
||||||
start, stop = i - p0, i + N + p1
|
start, stop = i - p0, i + N + p1
|
||||||
# add buffer for linear interpolation
|
|
||||||
if scale_factor > 1:
|
if scale_factor > 1: # add buffer for linear interpolation
|
||||||
start, stop = start - 1, stop + 1
|
start, stop = start - 1, stop + 1
|
||||||
|
|
||||||
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
|
x = x.take(range(start, stop), axis=1 + d, mode='wrap')
|
||||||
|
|
||||||
if scale_factor > 1:
|
if scale_factor > 1:
|
||||||
|
x = torch.from_numpy(x)
|
||||||
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
||||||
|
x = x.numpy()
|
||||||
|
|
||||||
# remove buffer
|
# remove buffer
|
||||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||||
start, stop = scale_factor, N + p0 + p1 - scale_factor
|
start, stop = scale_factor, N + p0 + p1 - scale_factor
|
||||||
|
Loading…
Reference in New Issue
Block a user