From 70fcad44f5bd0a2919fdf2ca9781045b822b7f0f Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Tue, 1 Jul 2025 15:14:30 +0200 Subject: [PATCH] added custom loss --- sCOCA_ML/train/losses_gravpot.py | 44 ++++++++++++++++++++++++++++++++ sCOCA_ML/train/train_gravpot.py | 11 +++++--- 2 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 sCOCA_ML/train/losses_gravpot.py diff --git a/sCOCA_ML/train/losses_gravpot.py b/sCOCA_ML/train/losses_gravpot.py new file mode 100644 index 0000000..93ccb25 --- /dev/null +++ b/sCOCA_ML/train/losses_gravpot.py @@ -0,0 +1,44 @@ +import torch + +class LossBase(torch.nn.Module): + + def forward(self, pred, target, style=None): + return torch.nn.MSELoss()(pred, target) + + +class LossGravPot(LossBase): + + def __init__(self, + D1_scaling:float=0., + chi_MSE_coeff:float=1., + grad_MSE_coeff:float=0.,): + + super(LossGravPot, self).__init__() + self.D1_scaling = D1_scaling + self.chi_MSE_coeff = chi_MSE_coeff + self.grad_MSE_coeff = grad_MSE_coeff + + + def forward(self, pred, target, style=None): + """ + Loss function for the gravitational potential. + loss = (1 + D1_scaling * D1) * (chi_MSE_coeff * chi_MSE + grad_MSE_coeff * grad_MSE) + where: + - chi_MSE is the mean squared error of the gravitational potential residual chi + - grad_MSE is the mean squared error of the gradient of the gravitational potential residual + - D1 is the first order linear growth factor, style[:,0] + """ + + D1 = style[:,0] if style is not None else 0. + + chi_MSE = torch.nn.MSELoss()(pred, target) + + if self.grad_MSE_coeff <= 0: + return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE) + + pred_grads = torch.gradient(pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width) + target_grads = torch.gradient(target, dim=[2,3,4]) + + grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2), torch.stack(target_grads, dim=2)) + + return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE) diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py index 6bd84e0..3132b0d 100644 --- a/sCOCA_ML/train/train_gravpot.py +++ b/sCOCA_ML/train/train_gravpot.py @@ -14,7 +14,8 @@ def train_model(model, save_model_path=None, scheduler=None, target_crop:int = None, - epoch_start:int = 0): + epoch_start:int = 0, + loss_fn=None): """ Train a model with the given dataloader and optimizer. @@ -38,7 +39,9 @@ def train_model(model, if scheduler is None: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5) model.to(device) - loss_fn = torch.nn.MSELoss() + if loss_fn is None: + from .losses_gravpot import LossBase + loss_fn = LossBase() train_loss_log = [] val_loss_log = [] @@ -70,7 +73,7 @@ def train_model(model, if target_crop: target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] - loss = loss_fn(output, target) + loss = loss_fn(output, target, style=style) forward_time += time.time() - t1 # Backward pass and optimization @@ -144,7 +147,7 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None): target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] output = model(input, style) - loss = loss_fn(output, target) + loss = loss_fn(output, target, style=style) losses.append(loss.item()) styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense progress_bar.set_postfix(loss=f"{loss.item():2.5f}")