Remove many item() to avoid unnecessary cpu-gpu sync
This commit is contained in:
parent
5e4633b125
commit
2717269a8d
1 changed files with 7 additions and 7 deletions
|
@ -272,9 +272,9 @@ def train(epoch, loader, model, criterion,
|
|||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
loss = lag_loss * eul_loss
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
epoch_loss[2] += loss.item()
|
||||
epoch_loss[0] += lag_loss.detach()
|
||||
epoch_loss[1] += eul_loss.detach()
|
||||
epoch_loss[2] += loss.detach()
|
||||
|
||||
optimizer.zero_grad()
|
||||
torch.log(loss).backward() # NOTE actual loss is log(loss)
|
||||
|
@ -359,9 +359,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||
lag_loss = criterion(lag_out, lag_tgt)
|
||||
eul_loss = criterion(eul_out, eul_tgt)
|
||||
loss = lag_loss * eul_loss
|
||||
epoch_loss[0] += lag_loss.item()
|
||||
epoch_loss[1] += eul_loss.item()
|
||||
epoch_loss[2] += loss.item()
|
||||
epoch_loss[0] += lag_loss.detach()
|
||||
epoch_loss[1] += eul_loss.detach()
|
||||
epoch_loss[2] += loss.detach()
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * world_size
|
||||
|
@ -452,5 +452,5 @@ def get_grads(model):
|
|||
grads = list(p.grad for n, p in model.named_parameters()
|
||||
if '.weight' in n)
|
||||
grads = [grads[0], grads[-1]]
|
||||
grads = [g.detach().norm().item() for g in grads]
|
||||
grads = [g.detach().norm() for g in grads]
|
||||
return grads
|
||||
|
|
Loading…
Reference in a new issue