diff --git a/map2map/train.py b/map2map/train.py index 2388f7a..83ffddc 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -266,10 +266,12 @@ def train(epoch, loader, model, criterion, input = resample(input, model.module.scale_factor, narrow=False) input, output, target = narrow_cast(input, output, target) if batch <= 5 and rank == 0: - print('narrowed shape :', output.shape, flush=True) + print('narrowed shape :', output.shape) lag_out, lag_tgt = output, target eul_out, eul_tgt = lag2eul([lag_out, lag_tgt], **args.misc_kwargs) + if batch <= 5 and rank == 0: + print('Eulerian shape :', eul_out.shape, flush=True) lag_loss = criterion(lag_out, lag_tgt) eul_loss = criterion(eul_out, eul_tgt)