diff --git a/map2map/data/fields.py b/map2map/data/fields.py index c1b3f96..873f11e 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -172,14 +172,14 @@ class FieldDataset(Dataset): if shift is not None: anchor[d] += torch.randint(shift, (1,)) - in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] - tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] - in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size) tgt_fields = crop(tgt_fields, anchor * self.scale_factor, self.crop * self.scale_factor, np.zeros_like(self.pad), self.size) + in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] + tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] + if self.in_norms is not None: for norm, x in zip(self.in_norms, in_fields): norm(x) @@ -216,7 +216,7 @@ def crop(fields, anchor, crop, pad, size): for x in fields: ind = [slice(None)] for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): - i = torch.arange(a - p0, a + c + p1) + i = np.arange(a - p0, a + c + p1) i %= s i = i.reshape((-1,) + (1,) * (ndim - d - 1)) ind.append(i)