Fix crop and pad are for the shapes after perm()
not before it! This only matters for non-cubic crop and pad
This commit is contained in:
parent
08c805e47a
commit
795cefe38e
@ -162,9 +162,23 @@ class FieldDataset(Dataset):
|
|||||||
if shift is not None:
|
if shift is not None:
|
||||||
anchor[d] += torch.randint(int(shift), (1,))
|
anchor[d] += torch.randint(int(shift), (1,))
|
||||||
|
|
||||||
crop(in_fields, anchor, self.crop, self.in_pad)
|
# crop and pad are for the shapes after perm()
|
||||||
|
# so before that they themselves need perm() in the opposite ways
|
||||||
|
if self.augment:
|
||||||
|
# let i and j index axes before and after perm()
|
||||||
|
# then perm_axes is i_j, whose argsort is j_i
|
||||||
|
# the latter is needed to index crop and pad for opposite perm()
|
||||||
|
perm_axes = perm([], None, self.ndim)
|
||||||
|
argsort_perm_axes = np.argsort(perm_axes.numpy())
|
||||||
|
else:
|
||||||
|
argsort_perm_axes = slice(None)
|
||||||
|
|
||||||
|
crop(in_fields, anchor,
|
||||||
|
self.crop[argsort_perm_axes],
|
||||||
|
self.in_pad[argsort_perm_axes])
|
||||||
crop(tgt_fields, anchor * self.scale_factor,
|
crop(tgt_fields, anchor * self.scale_factor,
|
||||||
self.crop * self.scale_factor, self.tgt_pad)
|
self.crop[argsort_perm_axes] * self.scale_factor,
|
||||||
|
self.tgt_pad[argsort_perm_axes])
|
||||||
|
|
||||||
style = torch.from_numpy(style).to(torch.float32)
|
style = torch.from_numpy(style).to(torch.float32)
|
||||||
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
in_fields = [torch.from_numpy(f).to(torch.float32) for f in in_fields]
|
||||||
@ -183,7 +197,7 @@ class FieldDataset(Dataset):
|
|||||||
flip_axes = flip(in_fields, None, self.ndim)
|
flip_axes = flip(in_fields, None, self.ndim)
|
||||||
flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
flip_axes = flip(tgt_fields, flip_axes, self.ndim)
|
||||||
|
|
||||||
perm_axes = perm(in_fields, None, self.ndim)
|
perm_axes = perm(in_fields, perm_axes, self.ndim)
|
||||||
perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
perm_axes = perm(tgt_fields, perm_axes, self.ndim)
|
||||||
|
|
||||||
if self.aug_add is not None:
|
if self.aug_add is not None:
|
||||||
|
@ -252,7 +252,8 @@ def train(epoch, loader, model, criterion,
|
|||||||
target = target.to(device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
if batch == 1 and rank == 0:
|
if batch <= 5 and rank == 0:
|
||||||
|
print('##### batch :', batch)
|
||||||
print('style shape :', style.shape)
|
print('style shape :', style.shape)
|
||||||
print('input shape :', input.shape)
|
print('input shape :', input.shape)
|
||||||
print('output shape :', output.shape)
|
print('output shape :', output.shape)
|
||||||
@ -262,7 +263,7 @@ def train(epoch, loader, model, criterion,
|
|||||||
and model.module.scale_factor != 1):
|
and model.module.scale_factor != 1):
|
||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
if batch == 1 and rank == 0:
|
if batch <= 5 and rank == 0:
|
||||||
print('narrowed shape :', output.shape, flush=True)
|
print('narrowed shape :', output.shape, flush=True)
|
||||||
|
|
||||||
lag_out, lag_tgt = output, target
|
lag_out, lag_tgt = output, target
|
||||||
|
Loading…
Reference in New Issue
Block a user