Remove many `.item()` calls to avoid unnecessary CPU-GPU synchronization
This commit is contained in:
parent
5e4633b125
commit
2717269a8d
@ -272,9 +272,9 @@ def train(epoch, loader, model, criterion,
|
|||||||
lag_loss = criterion(lag_out, lag_tgt)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
eul_loss = criterion(eul_out, eul_tgt)
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
loss = lag_loss * eul_loss
|
loss = lag_loss * eul_loss
|
||||||
epoch_loss[0] += lag_loss.item()
|
epoch_loss[0] += lag_loss.detach()
|
||||||
epoch_loss[1] += eul_loss.item()
|
epoch_loss[1] += eul_loss.detach()
|
||||||
epoch_loss[2] += loss.item()
|
epoch_loss[2] += loss.detach()
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
torch.log(loss).backward() # NOTE actual loss is log(loss)
|
torch.log(loss).backward() # NOTE actual loss is log(loss)
|
||||||
@ -359,9 +359,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
|
|||||||
lag_loss = criterion(lag_out, lag_tgt)
|
lag_loss = criterion(lag_out, lag_tgt)
|
||||||
eul_loss = criterion(eul_out, eul_tgt)
|
eul_loss = criterion(eul_out, eul_tgt)
|
||||||
loss = lag_loss * eul_loss
|
loss = lag_loss * eul_loss
|
||||||
epoch_loss[0] += lag_loss.item()
|
epoch_loss[0] += lag_loss.detach()
|
||||||
epoch_loss[1] += eul_loss.item()
|
epoch_loss[1] += eul_loss.detach()
|
||||||
epoch_loss[2] += loss.item()
|
epoch_loss[2] += loss.detach()
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
@ -452,5 +452,5 @@ def get_grads(model):
|
|||||||
grads = list(p.grad for n, p in model.named_parameters()
|
grads = list(p.grad for n, p in model.named_parameters()
|
||||||
if '.weight' in n)
|
if '.weight' in n)
|
||||||
grads = [grads[0], grads[-1]]
|
grads = [grads[0], grads[-1]]
|
||||||
grads = [g.detach().norm().item() for g in grads]
|
grads = [g.detach().norm() for g in grads]
|
||||||
return grads
|
return grads
|
||||||
|
Loading…
Reference in New Issue
Block a user