Add shape report after narrow_cast in training

This commit is contained in:
Yin Li 2020-08-22 12:28:08 -07:00
parent 01cc1b6964
commit 5d22594ede

View file

@ -239,19 +239,23 @@ def train(epoch, loader, model, lag2eul, criterion,
epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)
for i, (input, target) in enumerate(loader):
batch = epoch * len(loader) + i + 1
input = input.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
output = model(input)
if epoch == 0 and i == 0 and rank == 0:
print('input.shape =', input.shape)
print('output.shape =', output.shape)
print('target.shape =', target.shape, flush=True)
if batch == 1 and rank == 0:
print('input shape :', input.shape)
print('output shape :', output.shape)
print('target shape :', target.shape)
if (hasattr(model.module, 'scale_factor')
and model.module.scale_factor != 1):
input = resample(input, model.module.scale_factor, narrow=False)
input, output, target = narrow_cast(input, output, target)
if batch == 1 and rank == 0:
print('narrowed shape :', output.shape, flush=True)
lag_out, lag_tgt = output, target
eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
@ -268,7 +272,6 @@ def train(epoch, loader, model, lag2eul, criterion,
optimizer.step()
grads = get_grads(model)
batch = epoch * len(loader) + i + 1
if batch % args.log_interval == 0:
dist.all_reduce(lag_loss)
dist.all_reduce(eul_loss)