new loss fixed
This commit is contained in:
parent
d2e4453958
commit
d39cd7b1da
1 changed files with 8 additions and 6 deletions
|
@ -29,16 +29,18 @@ class LossGravPot(LossBase):
|
|||
- D1 is the first order linear growth factor, style[:,0]
|
||||
"""
|
||||
|
||||
D1 = style[:,0] if style is not None else 0.
|
||||
D1 = style[:, [0], None, None, None] if style is not None else 0.
|
||||
|
||||
chi_MSE = torch.nn.MSELoss()(pred, target)
|
||||
scale = (1 + self.D1_scaling * D1)
|
||||
|
||||
chi_MSE = torch.nn.MSELoss()(scale*pred, scale*target)
|
||||
|
||||
if self.grad_MSE_coeff <= 0:
|
||||
return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE)
|
||||
return self.chi_MSE_coeff * chi_MSE
|
||||
|
||||
pred_grads = torch.gradient(pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
|
||||
target_grads = torch.gradient(target, dim=[2,3,4])
|
||||
pred_grads = torch.gradient(scale*pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
|
||||
target_grads = torch.gradient(scale*target, dim=[2,3,4])
|
||||
|
||||
grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2), torch.stack(target_grads, dim=2))
|
||||
|
||||
return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE)
|
||||
return self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue