Add Eulerian shape logging
This commit is contained in:
parent
6b5530256b
commit
3d271e9c44
@ -266,10 +266,12 @@ def train(epoch, loader, model, criterion,
|
|||||||
input = resample(input, model.module.scale_factor, narrow=False)
|
input = resample(input, model.module.scale_factor, narrow=False)
|
||||||
input, output, target = narrow_cast(input, output, target)
|
input, output, target = narrow_cast(input, output, target)
|
||||||
if batch <= 5 and rank == 0:
|
if batch <= 5 and rank == 0:
|
||||||
print('narrowed shape :', output.shape, flush=True)
|
print('narrowed shape :', output.shape)
|
||||||
|
|
||||||
lag_out, lag_tgt = output, target
|
lag_out, lag_tgt = output, target
|
||||||
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
|
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
|
||||||
|
if batch <= 5 and rank == 0:
|
||||||
|
print('Eulerian shape :', eul_out.shape, flush=True)
|
||||||
|
|
||||||
lag_loss = criterion(lag_out, lag_tgt)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
eul_loss = criterion(eul_out, eul_tgt)
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
|
Loading…
Reference in New Issue
Block a user