Replace np.take by advanced indexing for cropping
This commit is contained in:
parent
c61b6d6fab
commit
c2edeb344e
@ -172,14 +172,14 @@ class FieldDataset(Dataset):
|
|||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(shift, (1,))
|
anchor[d] += torch.randint(shift, (1,))
|
||||||
|
|
||||||
in_fields = crop(in_fields, anchor, self.crop, self.pad)
|
|
||||||
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
|
||||||
self.crop * self.scale_factor,
|
|
||||||
np.zeros_like(self.pad))
|
|
||||||
|
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
|
||||||
|
|
||||||
|
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
|
||||||
|
tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
|
||||||
|
self.crop * self.scale_factor,
|
||||||
|
np.zeros_like(self.pad), self.size)
|
||||||
|
|
||||||
if self.in_norms is not None:
|
if self.in_norms is not None:
|
||||||
for norm, x in zip(self.in_norms, in_fields):
|
for norm, x in zip(self.in_norms, in_fields):
|
||||||
norm(x)
|
norm(x)
|
||||||
@ -208,12 +208,20 @@ class FieldDataset(Dataset):
|
|||||||
return in_fields, tgt_fields
|
return in_fields, tgt_fields
|
||||||
|
|
||||||
|
|
||||||
def crop(fields, anchor, crop, pad):
|
def crop(fields, anchor, crop, pad, size):
|
||||||
|
ndim = len(size)
|
||||||
|
assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim'
|
||||||
|
|
||||||
new_fields = []
|
new_fields = []
|
||||||
for x in fields:
|
for x in fields:
|
||||||
for d, (a, c, (p0, p1)) in enumerate(zip(anchor, crop, pad)):
|
ind = [slice(None)]
|
||||||
begin, end = a - p0, a + c + p1
|
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
|
||||||
x = x.take(range(begin, end), axis=1 + d, mode='wrap')
|
i = torch.arange(a - p0, a + c + p1)
|
||||||
|
i %= s
|
||||||
|
i = i.reshape((-1,) + (1,) * (ndim - d - 1))
|
||||||
|
ind.append(i)
|
||||||
|
|
||||||
|
x = x[ind]
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user