Fix crop and pad are for the shapes after perm()
not before it! This only matters for non-cubic crop and pad
This commit is contained in:
parent
08c805e47a
commit
795cefe38e
2 changed files with 20 additions and 5 deletions
|
@ -162,9 +162,23 @@ class FieldDataset(Dataset):
|
|||
if shift is not None:
|
||||
anchor[d] += torch.randint(int(shift), (1,))
|
||||
|
||||
crop(in_fields, anchor, self.crop, self.in_pad)
|
||||
# crop and pad are for the shapes after perm()
|
||||
# so before that they themselves need perm() in the opposite ways
|
||||
if self.augment:
|
||||
# let i and j index axes before and after perm()
|
||||
# then perm_axes is i_j, whose argsort is j_i
|
||||
# the latter is needed to index crop and pad for opposite perm()
|
||||
perm_axes = perm([], None, self.ndim)
|
||||
argsort_perm_axes = np.argsort(perm_axes.numpy())
|
||||
else:
|
||||
argsort_perm_axes = slice(None)
|
||||
|
||||
crop(in_fields, anchor,
|
||||
self.crop[argsort_perm_axes],
|
||||
self.in_pad[argsort_perm_axes])
|
||||
crop(tgt_fields, anchor * self.scale_factor,
|
||||
self.crop * self.scale_factor, self.tgt_pad)
|
||||
self.crop[argsort_perm_axes] * self.scale_factor,
|
||||
self.tgt_pad[argsort_perm_axes])
|
||||
|
||||
style = torch.from_numpy(style).to(torch.float32)
|
||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||
|
@ -183,7 +197,7 @@ class FieldDataset(Dataset):
|
|||
flip_axes = flip(in_fields, None, self.ndim)
|
||||
flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
||||
|
||||
perm_axes = perm(in_fields, None, self.ndim)
|
||||
perm_axes = perm(in_fields, perm_axes, self.ndim)
|
||||
perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
||||
|
||||
if self.aug_add is not None:
|
||||
|
|
|
@ -252,7 +252,8 @@ def train(epoch, loader, model, criterion,
|
|||
target = target.to(device, non_blocking=True)
|
||||
|
||||
output = model(input, style)
|
||||
if batch == 1 and rank == 0:
|
||||
if batch <= 5 and rank == 0:
|
||||
print('##### batch :', batch)
|
||||
print('style shape :', style.shape)
|
||||
print('input shape :', input.shape)
|
||||
print('output shape :', output.shape)
|
||||
|
@ -262,7 +263,7 @@ def train(epoch, loader, model, criterion,
|
|||
and model.module.scale_factor != 1):
|
||||
input = resample(input, model.module.scale_factor, narrow=False)
|
||||
input, output, target = narrow_cast(input, output, target)
|
||||
if batch == 1 and rank == 0:
|
||||
if batch <= 5 and rank == 0:
|
||||
print('narrowed shape :', output.shape, flush=True)
|
||||
|
||||
lag_out, lag_tgt = output, target
|
||||
|
|
Loading…
Reference in a new issue