diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 16546ec..62ea348 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -162,9 +162,23 @@ class FieldDataset(Dataset): if shift is not None: anchor[d] += torch.randint(int(shift), (1,)) - crop(in_fields, anchor, self.crop, self.in_pad) + # crop and pad are for the shapes after perm() + # so before that they themselves need perm() in the opposite ways + if self.augment: + # let i and j index axes before and after perm() + # then perm_axes is i_j, whose argsort is j_i + # the latter is needed to index crop and pad for opposite perm() + perm_axes = perm([], None, self.ndim) + argsort_perm_axes = np.argsort(perm_axes.numpy()) + else: + argsort_perm_axes = slice(None) + + crop(in_fields, anchor, + self.crop[argsort_perm_axes], + self.in_pad[argsort_perm_axes]) crop(tgt_fields, anchor * self.scale_factor, - self.crop * self.scale_factor, self.tgt_pad) + self.crop[argsort_perm_axes] * self.scale_factor, + self.tgt_pad[argsort_perm_axes]) style = torch.from_numpy(style).to(torch.float32) in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields] @@ -183,7 +197,7 @@ class FieldDataset(Dataset): flip_axes = flip(in_fields, None, self.ndim) flip_axes = flip(tgt_fields, flip_axes, self.ndim) - perm_axes = perm(in_fields, None, self.ndim) + perm_axes = perm(in_fields, perm_axes, self.ndim) perm_axes = perm(tgt_fields, perm_axes, self.ndim) if self.aug_add is not None: diff --git a/map2map/train.py b/map2map/train.py index 901a84a..6163240 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -252,7 +252,8 @@ def train(epoch, loader, model, criterion, target = target.to(device, non_blocking=True) output = model(input, style) - if batch == 1 and rank == 0: + if batch <= 5 and rank == 0: + print('##### batch :', batch) print('style shape :', style.shape) print('input shape :', input.shape) print('output shape :', output.shape) @@ -262,7 +263,7 @@ def train(epoch, loader, model, criterion, and model.module.scale_factor != 1): input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) - if batch == 1 and rank == 0: + if batch <= 5 and rank == 0: print('narrowed shape :', output.shape, flush=True) lag_out, lag_tgt = output, target