Remove many item() calls to avoid unnecessary CPU-GPU sync

Yin Li 2021-04-22 22:39:01 -04:00
parent 5e4633b125
commit 2717269a8d

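The point of the change: Tensor.item() copies a scalar back to the host, which forces the CPU to wait for every kernel queued so far, so calling it three times per step stalls the training loop. Tensor.detach() only drops the autograd graph; the value stays on the GPU and the accumulation runs asynchronously, with a single device-to-host copy when the epoch totals are finally read. A minimal self-contained sketch of the two patterns (the loss names mirror the diff, but the loop itself is synthetic, not the repo's code):

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # running totals live on the same device as the losses,
    # so the in-place += below never leaves the GPU
    epoch_loss = torch.zeros(3, device=device)

    for step in range(100):
        # stand-ins for criterion(lag_out, lag_tgt) etc.; both are 0-dim tensors
        lag_loss = (torch.randn(8, device=device) ** 2).mean()
        eul_loss = (torch.randn(8, device=device) ** 2).mean()
        loss = lag_loss * eul_loss

        # old pattern: each .item() copies a scalar to the host and blocks
        # until the GPU has finished all work queued so far
        #   epoch_loss[0] += lag_loss.item()

        # new pattern: .detach() drops the graph but keeps the value on the
        # device, so nothing in the loop waits on the GPU
        epoch_loss[0] += lag_loss.detach()
        epoch_loss[1] += eul_loss.detach()
        epoch_loss[2] += loss.detach()

    # one synchronizing copy per epoch, when the numbers are actually needed
    print((epoch_loss / 100).tolist())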

@@ -272,9 +272,9 @@ def train(epoch, loader, model, criterion,
         lag_loss = criterion(lag_out, lag_tgt)
         eul_loss = criterion(eul_out, eul_tgt)
         loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.item()
-        epoch_loss[1] += eul_loss.item()
-        epoch_loss[2] += loss.item()
+        epoch_loss[0] += lag_loss.detach()
+        epoch_loss[1] += eul_loss.detach()
+        epoch_loss[2] += loss.detach()
 
         optimizer.zero_grad()
         torch.log(loss).backward()    # NOTE actual loss is log(loss)
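Note that this accumulation only stays on the device because epoch_loss is itself a tensor on the same GPU as the losses (the dist.all_reduce(epoch_loss) in the next hunk requires that anyway); with a CPU tensor the in-place += of a CUDA value would raise a device-mismatch error. The allocation presumably looks something like the line below, though it is not part of this diff:

    # hypothetical allocation earlier in train()/validate(); not shown here
    epoch_loss = torch.zeros(3, device=device)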
@@ -359,9 +359,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
         lag_loss = criterion(lag_out, lag_tgt)
         eul_loss = criterion(eul_out, eul_tgt)
         loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.item()
-        epoch_loss[1] += eul_loss.item()
-        epoch_loss[2] += loss.item()
+        epoch_loss[0] += lag_loss.detach()
+        epoch_loss[1] += eul_loss.detach()
+        epoch_loss[2] += loss.detach()
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
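Because the accumulators never leave the GPU, dist.all_reduce reduces device tensors directly (which the NCCL backend expects), and the one device-to-host transfer per epoch happens wherever the averaged values are consumed. A hedged sketch of that read-out, assuming logger is a TensorBoard SummaryWriter and the tag names are made up:

    # a single .tolist() syncs once per epoch instead of once per batch
    lag, eul, tot = epoch_loss.tolist()
    logger.add_scalar('loss/lag', lag, global_step=epoch + 1)
    logger.add_scalar('loss/eul', eul, global_step=epoch + 1)
    logger.add_scalar('loss/total', tot, global_step=epoch + 1)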
@@ -452,5 +452,5 @@ def get_grads(model):
     grads = list(p.grad for n, p in model.named_parameters()
                  if '.weight' in n)
     grads = [grads[0], grads[-1]]
-    grads = [g.detach().norm().item() for g in grads]
+    grads = [g.detach().norm() for g in grads]
     return grads
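After this hunk get_grads returns two 0-dim CUDA tensors instead of Python floats, so the host copy moves to whatever consumes them; TensorBoard's add_scalar accepts a 0-dim tensor, so a caller along these (hypothetical) lines still works and synchronizes only at logging time:

    grad_first, grad_last = get_grads(model)
    # the device-to-host sync now happens here, once per logging call
    logger.add_scalar('grad/first_layer', grad_first, global_step=step)
    logger.add_scalar('grad/last_layer', grad_last, global_step=step)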