import torch


class LossBase(torch.nn.Module):
    """Plain mean-squared-error loss.

    ``style`` is accepted so that all losses share the same call signature;
    it is ignored here.
    """

    def forward(self, pred, target, style=None):
        return torch.nn.functional.mse_loss(pred, target)


class LossGravPot(LossBase):
    """Loss for the gravitational potential residual with optional gradient
    matching and growth-factor-dependent sample weighting.

    Args:
        D1_scaling: strength of the per-sample weight ``1 + D1_scaling * D1``.
        chi_MSE_coeff: coefficient of the MSE on the field itself.
        grad_MSE_coeff: coefficient of the MSE on the spatial gradients of the
            field; the gradient term is skipped entirely when this is <= 0.
    """

    def __init__(self, D1_scaling: float = 0., chi_MSE_coeff: float = 1.,
                 grad_MSE_coeff: float = 0.,):
        super().__init__()
        self.D1_scaling = D1_scaling
        self.chi_MSE_coeff = chi_MSE_coeff
        self.grad_MSE_coeff = grad_MSE_coeff

    @staticmethod
    def _per_sample_mse(pred, target):
        """Return the mean squared error of each batch element, shape (batch,)."""
        sq_err = torch.nn.functional.mse_loss(pred, target, reduction='none')
        # Flatten everything except the leading batch dim, then average.
        return sq_err.reshape(sq_err.shape[0], -1).mean(dim=1)

    def forward(self, pred, target, style=None):
        """Loss function for the gravitational potential.

        Per batch element ``i``::

            loss_i = (1 + D1_scaling * D1_i)
                     * (chi_MSE_coeff * chi_MSE_i + grad_MSE_coeff * grad_MSE_i)

        and the returned loss is the mean of ``loss_i`` over the batch, i.e. a
        0-dim tensor suitable for ``.backward()`` / ``.item()``.

        where:
        - chi_MSE_i is the mean squared error of the gravitational potential
          residual chi for sample i
        - grad_MSE_i is the mean squared error of the spatial gradient of chi
          for sample i (finite differences via ``torch.gradient``)
        - D1_i is the first order linear growth factor, ``style[i, 0]``

        ``pred`` and ``target`` are expected as (batch, channel, *spatial);
        gradients are taken over every dim from 2 onward (depth, height, width
        for the usual 5-D case). Each spatial dim needs at least 2 points.
        When ``style`` is None no weighting is applied, and the result equals
        the plain batch-wide MSE combination.
        """
        # Per-sample terms so the per-sample D1 weight multiplies the matching
        # sample's error; averaging afterwards keeps the result a scalar.
        loss = self.chi_MSE_coeff * self._per_sample_mse(pred, target)

        if self.grad_MSE_coeff > 0:
            spatial_dims = list(range(2, pred.dim()))
            # Stack the per-axis gradients into a new "component" dim.
            pred_grad = torch.stack(torch.gradient(pred, dim=spatial_dims), dim=2)
            target_grad = torch.stack(torch.gradient(target, dim=spatial_dims), dim=2)
            loss = loss + self.grad_MSE_coeff * self._per_sample_mse(pred_grad, target_grad)

        if style is not None:
            D1 = style[:, 0]
            loss = (1 + self.D1_scaling * D1) * loss

        return loss.mean()