Merge branch 'master' into memmap

This commit is contained in:
Yin Li 2020-07-14 14:52:03 -04:00
commit eba76bf90d

View File

@ -172,14 +172,14 @@ class FieldDataset(Dataset):
if shift is not None: if shift is not None:
anchor[d] += torch.randint(shift, (1,)) anchor[d] += torch.randint(shift, (1,))
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size) in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size)
tgt_fields = crop(tgt_fields, anchor * self.scale_factor, tgt_fields = crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.crop * self.scale_factor,
np.zeros_like(self.pad), self.size) np.zeros_like(self.pad), self.size)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields]
if self.in_norms is not None: if self.in_norms is not None:
for norm, x in zip(self.in_norms, in_fields): for norm, x in zip(self.in_norms, in_fields):
norm(x) norm(x)
@ -216,7 +216,7 @@ def crop(fields, anchor, crop, pad, size):
for x in fields: for x in fields:
ind = [slice(None)] ind = [slice(None)]
for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):
i = torch.arange(a - p0, a + c + p1) i = np.arange(a - p0, a + c + p1)
i %= s i %= s
i = i.reshape((-1,) + (1,) * (ndim - d - 1)) i = i.reshape((-1,) + (1,) * (ndim - d - 1))
ind.append(i) ind.append(i)