Add Eulerian shape logging

This commit is contained in:
Yin Li 2021-05-18 14:16:40 -04:00
parent 6b5530256b
commit 3d271e9c44

View File

@ -266,10 +266,12 @@ def train(epoch, loader, model, criterion,
input = resample(input, model.module.scale_factor, narrow=False) input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target) input, output, target = narrow_cast(input, output, target)
if batch <= 5 and rank == 0: if batch <= 5 and rank == 0:
print('narrowed shape :', output.shape, flush=True) print('narrowed shape :', output.shape)
lag_out, lag_tgt = output, target lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs)
if batch <= 5 and rank == 0:
print('Eulerian shape :', eul_out.shape, flush=True)
lag_loss = criterion(lag_out, lag_tgt) lag_loss = criterion(lag_out, lag_tgt)
eul_loss = criterion(eul_out, eul_tgt) eul_loss = criterion(eul_out, eul_tgt)