44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import torch
|
|
|
|
class LossBase(torch.nn.Module):
    """Plain mean-squared-error loss between a prediction and its target.

    Serves as the common base for the potential losses in this file; the
    ``style`` argument exists only so all losses share one call signature.
    """

    def forward(self, pred, target, style=None):
        """Return the batch- and element-averaged MSE of ``pred`` vs ``target``.

        Args:
            pred: Predicted tensor.
            target: Reference tensor (same shape as ``pred``).
            style: Ignored by this base loss.

        Returns:
            A 0-d tensor holding the mean squared error.
        """
        # Functional form of torch.nn.MSELoss() with its default 'mean' reduction.
        mean_sq_err = torch.nn.functional.mse_loss(pred, target)
        return mean_sq_err
|
|
|
|
|
|
class LossGravPot(LossBase):
    """Loss on the gravitational potential residual ``chi`` and its gradient.

    For every sample ``i`` in the batch::

        loss_i = w_i * (chi_MSE_coeff * chi_MSE_i + grad_MSE_coeff * grad_MSE_i)
        w_i    = 1 + D1_scaling * D1_i

    and the returned loss is the mean of ``loss_i`` over the batch, so it is
    always a scalar that can be back-propagated directly.  Without ``style``
    every ``w_i`` is 1 and the result is the ordinary (global) MSE.

    Args:
        D1_scaling: Strength of the per-sample weighting by ``D1``.
        chi_MSE_coeff: Weight of the MSE on the residual itself.
        grad_MSE_coeff: Weight of the MSE on the finite-difference gradient of
            the residual; the gradient term is skipped when this is ``<= 0``.
    """

    def __init__(self,
                 D1_scaling: float = 0.,
                 chi_MSE_coeff: float = 1.,
                 grad_MSE_coeff: float = 0.,):
        super().__init__()
        self.D1_scaling = D1_scaling
        self.chi_MSE_coeff = chi_MSE_coeff
        self.grad_MSE_coeff = grad_MSE_coeff

    @staticmethod
    def _mse(a, b, per_sample):
        """Mean squared error: a 0-d tensor, or one value per entry of dim 0."""
        sq_err = (a - b).square()
        if per_sample:
            return sq_err.reshape(sq_err.shape[0], -1).mean(dim=1)
        return sq_err.mean()

    def forward(self, pred, target, style=None):
        """
        Loss function for the gravitational potential.

        loss = mean_over_batch[(1 + D1_scaling * D1) * (chi_MSE_coeff * chi_MSE + grad_MSE_coeff * grad_MSE)]

        where:
        - chi_MSE is the per-sample mean squared error of the gravitational potential residual chi
        - grad_MSE is the per-sample mean squared error of the gradient of the gravitational potential residual
        - D1 is the first order linear growth factor, style[:,0]

        Args:
            pred: Predicted residual; the gradient term assumes a 5D tensor
                (batch, channel, depth, height, width).
            target: Reference residual, same shape as ``pred``.
            style: Optional per-sample style parameters; column 0 is D1.

        Returns:
            A 0-d loss tensor.
        """
        # Weighting by D1 is per sample, so the errors must be kept per sample
        # until they are weighted; otherwise the weights would multiply an
        # already batch-averaged MSE and yield a non-scalar (batch,) "loss".
        per_sample = style is not None
        if per_sample:
            # NOTE(review): assumes style is (batch, n_style) with D1 in column 0.
            weight = 1 + self.D1_scaling * style[:, 0]
        else:
            weight = 1.

        loss = self.chi_MSE_coeff * self._mse(pred, target, per_sample)

        if self.grad_MSE_coeff > 0:
            # Finite-difference gradients (torch.gradient, unit spacing) along
            # the three spatial axes, stacked as a new component axis.
            pred_grads = torch.stack(torch.gradient(pred, dim=[2, 3, 4]), dim=2)
            target_grads = torch.stack(torch.gradient(target, dim=[2, 3, 4]), dim=2)
            loss = loss + self.grad_MSE_coeff * self._mse(pred_grads, target_grads, per_sample)

        # Equal-sized samples: the batch mean of per-sample MSEs equals the
        # global MSE, so the unweighted case matches the plain MSE exactly.
        return (weight * loss).mean()
|