From d39cd7b1dad8db3084e329c7db4ad35a45cff94d Mon Sep 17 00:00:00 2001
From: Mayeul Aubin
Date: Thu, 3 Jul 2025 16:07:15 +0200
Subject: [PATCH] new loss fixed

---
 sCOCA_ML/train/losses_gravpot.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/sCOCA_ML/train/losses_gravpot.py b/sCOCA_ML/train/losses_gravpot.py
index 93ccb25..26d0a27 100644
--- a/sCOCA_ML/train/losses_gravpot.py
+++ b/sCOCA_ML/train/losses_gravpot.py
@@ -29,16 +29,18 @@ class LossGravPot(LossBase):
         - D1 is the first order linear growth factor, style[:,0]
         """
-        D1 = style[:,0] if style is not None else 0.
+        D1 = style[:, [0], None, None, None] if style is not None else 0.
 
-        chi_MSE = torch.nn.MSELoss()(pred, target)
+        scale = (1 + self.D1_scaling * D1)
+
+        chi_MSE = torch.nn.MSELoss()(scale*pred, scale*target)
 
         if self.grad_MSE_coeff <= 0:
-            return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE)
+            return self.chi_MSE_coeff * chi_MSE
 
-        pred_grads = torch.gradient(pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
-        target_grads = torch.gradient(target, dim=[2,3,4])
+        pred_grads = torch.gradient(scale*pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
+        target_grads = torch.gradient(scale*target, dim=[2,3,4])
 
         grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2), torch.stack(target_grads, dim=2))
 
-        return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE)
+        return self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE
 
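
For reference, a minimal sketch of how the loss reads after this patch. The change moves the (1 + D1_scaling * D1) factor from the final loss value onto the predicted and target fields themselves, so the gradient term is computed on the same rescaled fields as the chi term. Only the method body mirrors the hunk above; the constructor, the absence of the LossBase parent, and the __call__ signature are assumptions made to keep the sketch self-contained.

```python
import torch


class LossGravPot:
    """Sketch of the loss after this patch.

    Assumptions: coefficient names come from the diff, but the constructor
    defaults and the plain-class (no LossBase) form are illustrative only.
    """

    def __init__(self, chi_MSE_coeff=1.0, grad_MSE_coeff=0.0, D1_scaling=0.0):
        self.chi_MSE_coeff = chi_MSE_coeff
        self.grad_MSE_coeff = grad_MSE_coeff
        self.D1_scaling = D1_scaling

    def __call__(self, pred, target, style=None):
        # D1 is the first-order linear growth factor, style[:, 0], kept
        # broadcastable against the 5D fields (batch, channel, depth, height, width).
        D1 = style[:, [0], None, None, None] if style is not None else 0.

        # D1-dependent scaling is applied to the fields (not to the summed loss),
        # so both MSE terms see the same rescaled prediction and target.
        scale = (1 + self.D1_scaling * D1)

        chi_MSE = torch.nn.MSELoss()(scale * pred, scale * target)

        if self.grad_MSE_coeff <= 0:
            return self.chi_MSE_coeff * chi_MSE

        # torch.gradient returns one tensor per spatial dim; stack them for the MSE.
        pred_grads = torch.gradient(scale * pred, dim=[2, 3, 4])
        target_grads = torch.gradient(scale * target, dim=[2, 3, 4])
        grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2),
                                      torch.stack(target_grads, dim=2))

        return self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE


# Illustrative usage (shapes are assumptions, not taken from the repo):
if __name__ == "__main__":
    loss_fn = LossGravPot(chi_MSE_coeff=1.0, grad_MSE_coeff=0.5, D1_scaling=2.0)
    pred = torch.randn(2, 1, 16, 16, 16)
    target = torch.randn(2, 1, 16, 16, 16)
    style = torch.rand(2, 3)  # style[:, 0] holds D1
    print(loss_fn(pred, target, style).item())
```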