Fix crop and pad are for the shapes after perm()

not before it!
This only matters for non-cubic crop and pad
This commit is contained in:
Yin Li 2021-04-12 16:53:57 -04:00
parent 08c805e47a
commit 795cefe38e
2 changed files with 20 additions and 5 deletions

View file

@ -162,9 +162,23 @@ class FieldDataset(Dataset):
if shift is not None:
anchor[d] += torch.randint(int(shift), (1,))
crop(in_fields, anchor, self.crop, self.in_pad)
# crop and pad are for the shapes after perm()
# so before that they themselves need perm() in the opposite ways
if self.augment:
# let i and j index axes before and after perm()
# then perm_axes is i_j, whose argsort is j_i
# the latter is needed to index crop and pad for opposite perm()
perm_axes = perm([], None, self.ndim)
argsort_perm_axes = np.argsort(perm_axes.numpy())
else:
argsort_perm_axes = slice(None)
crop(in_fields, anchor,
self.crop[argsort_perm_axes],
self.in_pad[argsort_perm_axes])
crop(tgt_fields, anchor * self.scale_factor,
self.crop * self.scale_factor, self.tgt_pad)
self.crop[argsort_perm_axes] * self.scale_factor,
self.tgt_pad[argsort_perm_axes])
style = torch.from_numpy(style).to(torch.float32)
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
@ -183,7 +197,7 @@ class FieldDataset(Dataset):
flip_axes = flip(in_fields, None, self.ndim)
flip_axes = flip(tgt_fields, flip_axes, self.ndim)
perm_axes = perm(in_fields, None, self.ndim)
perm_axes = perm(in_fields, perm_axes, self.ndim)
perm_axes = perm(tgt_fields, perm_axes, self.ndim)
if self.aug_add is not None:

View file

@ -252,7 +252,8 @@ def train(epoch, loader, model, criterion,
target = target.to(device, non_blocking=True)
output = model(input, style)
if batch == 1 and rank == 0:
if batch <= 5 and rank == 0:
print('##### batch :', batch)
print('style shape :', style.shape)
print('input shape :', input.shape)
print('output shape :', output.shape)
@ -262,7 +263,7 @@ def train(epoch, loader, model, criterion,
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
if batch == 1 and rank == 0:
if batch <= 5 and rank == 0:
print('narrowed shape :', output.shape, flush=True)
lag_out, lag_tgt = output, target