diff --git a/map2map/data/fields.py b/map2map/data/fields.py index e36dc99..873f11e 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -172,10 +172,10 @@ class FieldDataset(Dataset): if shift is not None: anchor[d] += torch.randint(shift, (1,)) - in_fields = crop(in_fields, anchor, self.crop, self.pad) + in_fields = crop(in_fields, anchor, self.crop, self.pad, self.size) tgt_fields = crop(tgt_fields, anchor * self.scale_factor, self.crop * self.scale_factor, - np.zeros_like(self.pad)) + np.zeros_like(self.pad), self.size) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] @@ -208,12 +208,20 @@ class FieldDataset(Dataset): return in_fields, tgt_fields -def crop(fields, anchor, crop, pad): +def crop(fields, anchor, crop, pad, size): + ndim = len(size) + assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim' + new_fields = [] for x in fields: - for d, (a, c, (p0, p1)) in enumerate(zip(anchor, crop, pad)): - begin, end = a - p0, a + c + p1 - x = x.take(range(begin, end), axis=1 + d, mode='wrap') + ind = [slice(None)] + for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): + i = np.arange(a - p0, a + c + p1) + i %= s + i = i.reshape((-1,) + (1,) * (ndim - d - 1)) + ind.append(i) + + x = x[ind] new_fields.append(x)