diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 4369afe..c610a3a 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -139,11 +139,11 @@ class FieldDataset(Dataset): if shift is not None: anchor[d] += torch.randint(int(shift), (1,)) - in_fields = crop(in_fields, anchor, self.crop, self.in_pad, self.size) - tgt_fields = crop(tgt_fields, anchor * self.scale_factor, - self.crop * self.scale_factor, - self.tgt_pad, - self.size * self.scale_factor) + crop(in_fields, anchor, self.crop, self.in_pad, self.size) + crop(tgt_fields, anchor * self.scale_factor, + self.crop * self.scale_factor, + self.tgt_pad, + self.size * self.scale_factor) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] tgt_fields = [torch.from_numpy(f).to(torch.float32) for f in tgt_fields] @@ -158,11 +158,11 @@ class FieldDataset(Dataset): norm(x) if self.augment: - in_fields, flip_axes = flip(in_fields, None, self.ndim) - tgt_fields, flip_axes = flip(tgt_fields, flip_axes, self.ndim) + flip_axes = flip(in_fields, None, self.ndim) + flip_axes = flip(tgt_fields, flip_axes, self.ndim) - in_fields, perm_axes = perm(in_fields, None, self.ndim) - tgt_fields, perm_axes = perm(tgt_fields, perm_axes, self.ndim) + perm_axes = perm(in_fields, None, self.ndim) + perm_axes = perm(tgt_fields, perm_axes, self.ndim) if self.aug_add is not None: add_fac = add(in_fields, None, self.aug_add) @@ -180,22 +180,22 @@ class FieldDataset(Dataset): def crop(fields, anchor, crop, pad, size): ndim = len(size) - assert all(len(x) == ndim for x in [anchor, crop, pad, size]), 'inconsistent ndim' + assert all(len(x) == ndim for x in [anchor, crop, pad]), 'ndim mismatch' - new_fields = [] - for x in fields: - ind = [slice(None)] - for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): - i = np.arange(a - p0, a + c + p1) - i %= s - i = i.reshape((-1,) + (1,) * (ndim - d - 1)) - ind.append(i) + ind = [slice(None)] + for d, (a, c, (p0, p1), s) in enumerate(zip(anchor, crop, pad, size)): + i = np.arange(a - p0, a + c + p1) + i %= s + i = i.reshape((-1,) + (1,) * (ndim - d - 1)) + ind.append(i) + ind = tuple(ind) - x = x[tuple(ind)] + for i, x in enumerate(fields): + x = x[ind] - new_fields.append(x) + fields[i] = x - return new_fields + return ind def flip(fields, axes, ndim): @@ -205,17 +205,16 @@ def flip(fields, axes, ndim): axes = torch.randint(2, (ndim,), dtype=torch.bool) axes = torch.arange(ndim)[axes] - new_fields = [] - for x in fields: + for i, x in enumerate(fields): if x.shape[0] == ndim: # flip vector components x[axes] = - x[axes] shifted_axes = (1 + axes).tolist() x = torch.flip(x, shifted_axes) - new_fields.append(x) + fields[i] = x - return new_fields, axes + return axes def perm(fields, axes, ndim): @@ -224,17 +223,16 @@ def perm(fields, axes, ndim): if axes is None: axes = torch.randperm(ndim) - new_fields = [] - for x in fields: + for i, x in enumerate(fields): if x.shape[0] == ndim: # permutate vector components x = x[axes] shifted_axes = [0] + (1 + axes).tolist() x = x.permute(shifted_axes) - new_fields.append(x) + fields[i] = x - return new_fields, axes + return axes def add(fields, fac, std):