diff --git a/map2map/train.py b/map2map/train.py
index 6163240..9dac857 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -272,9 +272,9 @@ def train(epoch, loader, model, criterion,
         lag_loss = criterion(lag_out, lag_tgt)
         eul_loss = criterion(eul_out, eul_tgt)
         loss = lag_loss * eul_loss
-        epoch_loss[0] += lag_loss.item()
-        epoch_loss[1] += eul_loss.item()
-        epoch_loss[2] += loss.item()
+        epoch_loss[0] += lag_loss.detach()
+        epoch_loss[1] += eul_loss.detach()
+        epoch_loss[2] += loss.detach()
 
         optimizer.zero_grad()
         torch.log(loss).backward()  # NOTE actual loss is log(loss)
@@ -359,9 +359,9 @@ def validate(epoch, loader, model, criterion, logger, device, args):
             lag_loss = criterion(lag_out, lag_tgt)
             eul_loss = criterion(eul_out, eul_tgt)
             loss = lag_loss * eul_loss
-            epoch_loss[0] += lag_loss.item()
-            epoch_loss[1] += eul_loss.item()
-            epoch_loss[2] += loss.item()
+            epoch_loss[0] += lag_loss.detach()
+            epoch_loss[1] += eul_loss.detach()
+            epoch_loss[2] += loss.detach()
 
     dist.all_reduce(epoch_loss)
     epoch_loss /= len(loader) * world_size
@@ -452,5 +452,5 @@ def get_grads(model):
     grads = list(p.grad for n, p in model.named_parameters()
                  if '.weight' in n)
     grads = [grads[0], grads[-1]]
-    grads = [g.detach().norm().item() for g in grads]
+    grads = [g.detach().norm() for g in grads]
     return grads
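
Not part of the patch: a minimal standalone sketch of the accumulation pattern the diff switches to, with stand-in losses on CPU and no distributed reduce. The rationale assumed here (`.item()` forces a device-to-host copy every batch, while `.detach()` keeps `epoch_loss` a tensor that `dist.all_reduce` can later reduce in place, with a single host transfer per epoch) is a reading of the change, not something the diff states.

    import torch

    device = 'cpu'  # 'cuda' in the real training loop
    epoch_loss = torch.zeros(3, device=device)

    num_batches = 4  # hypothetical loader length
    for _ in range(num_batches):
        # stand-ins for criterion(lag_out, lag_tgt) / criterion(eul_out, eul_tgt)
        lag_loss = torch.rand((), device=device)
        eul_loss = torch.rand((), device=device)
        loss = lag_loss * eul_loss

        # accumulate detached tensors on-device instead of Python floats
        epoch_loss[0] += lag_loss.detach()
        epoch_loss[1] += eul_loss.detach()
        epoch_loss[2] += loss.detach()

    # the real code calls dist.all_reduce(epoch_loss) and also divides by world_size here
    epoch_loss /= num_batches
    print(epoch_loss.tolist())  # single tensor-to-host conversion at the end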