diff --git a/map2map/data/fields.py b/map2map/data/fields.py index d7224a8..16546ec 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -162,11 +162,9 @@ class FieldDataset(Dataset): if shift is not None: anchor[d] += torch.randint(int(shift), (1,)) - crop(in_fields, anchor, self.crop, self.in_pad, self.size) + crop(in_fields, anchor, self.crop, self.in_pad) crop(tgt_fields, anchor * self.scale_factor, - self.crop * self.scale_factor, - self.tgt_pad, - self.size * self.scale_factor) + self.crop * self.scale_factor, self.tgt_pad) style = torch.from_numpy(style).to(torch.float32) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] @@ -283,6 +281,7 @@ class FieldDataset(Dataset): def fill(field, patch, anchor): ndim = len(anchor) + assert field.ndim == patch.ndim == 1 + ndim, 'ndim mismatch' ind = [slice(None)] for d, (p, a, s) in enumerate(zip( @@ -296,9 +295,11 @@ def fill(field, patch, anchor): field[ind] = patch -def crop(fields, anchor, crop, pad, size): +def crop(fields, anchor, crop, pad): + assert all(x.shape == fields[0].shape for x in fields), 'shape mismatch' + size = fields[0].shape[1:] ndim = len(size) - assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch' + assert ndim == len(anchor) == len(crop) == len(pad), 'ndim mismatch' ind = [slice(None)] for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)):