new loss fixed

This commit is contained in:
Mayeul Aubin 2025-07-03 16:07:15 +02:00
parent d2e4453958
commit d39cd7b1da

View file

@@ -29,16 +29,18 @@ class LossGravPot(LossBase):
         - D1 is the first order linear growth factor, style[:,0]
         """
-        D1 = style[:,0] if style is not None else 0.
-        chi_MSE = torch.nn.MSELoss()(pred, target)
+        D1 = style[:, [0], None, None, None] if style is not None else 0.
+        scale = (1 + self.D1_scaling * D1)
+        chi_MSE = torch.nn.MSELoss()(scale*pred, scale*target)
         if self.grad_MSE_coeff <= 0:
-            return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE)
-        pred_grads = torch.gradient(pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
-        target_grads = torch.gradient(target, dim=[2,3,4])
+            return self.chi_MSE_coeff * chi_MSE
+        pred_grads = torch.gradient(scale*pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
+        target_grads = torch.gradient(scale*target, dim=[2,3,4])
         grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2), torch.stack(target_grads, dim=2))
-        return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE)
+        return self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE