Add shape report after narrow_cast in training
This commit is contained in:
parent 01cc1b6964, commit 5d22594ede
@@ -239,19 +239,23 @@ def train(epoch, loader, model, lag2eul, criterion,
     epoch_loss = torch.zeros(3, dtype=torch.float64, device=device)

     for i, (input, target) in enumerate(loader):
+        batch = epoch * len(loader) + i + 1

         input = input.to(device, non_blocking=True)
         target = target.to(device, non_blocking=True)

         output = model(input)
-        if epoch == 0 and i == 0 and rank == 0:
-            print('input.shape =', input.shape)
-            print('output.shape =', output.shape)
-            print('target.shape =', target.shape, flush=True)
+        if batch == 1 and rank == 0:
+            print('input shape :', input.shape)
+            print('output shape :', output.shape)
+            print('target shape :', target.shape)

         if (hasattr(model.module, 'scale_factor')
                 and model.module.scale_factor != 1):
             input = resample(input, model.module.scale_factor, narrow=False)
         input, output, target = narrow_cast(input, output, target)
+        if batch == 1 and rank == 0:
+            print('narrowed shape :', output.shape, flush=True)

         lag_out, lag_tgt = output, target
         eul_out, eul_tgt = lag2eul(lag_out, lag_tgt)
@@ -268,7 +272,6 @@ def train(epoch, loader, model, lag2eul, criterion,
         optimizer.step()
         grads = get_grads(model)

-        batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             dist.all_reduce(lag_loss)
             dist.all_reduce(eul_loss)
Loading…
Reference in New Issue
Block a user