added custom loss
This commit is contained in:
parent
8058d81f26
commit
70fcad44f5
2 changed files with 51 additions and 4 deletions
44
sCOCA_ML/train/losses_gravpot.py
Normal file
44
sCOCA_ML/train/losses_gravpot.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class LossBase(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, pred, target, style=None):
|
||||||
|
return torch.nn.MSELoss()(pred, target)
|
||||||
|
|
||||||
|
|
||||||
|
class LossGravPot(LossBase):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
D1_scaling:float=0.,
|
||||||
|
chi_MSE_coeff:float=1.,
|
||||||
|
grad_MSE_coeff:float=0.,):
|
||||||
|
|
||||||
|
super(LossGravPot, self).__init__()
|
||||||
|
self.D1_scaling = D1_scaling
|
||||||
|
self.chi_MSE_coeff = chi_MSE_coeff
|
||||||
|
self.grad_MSE_coeff = grad_MSE_coeff
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, pred, target, style=None):
|
||||||
|
"""
|
||||||
|
Loss function for the gravitational potential.
|
||||||
|
loss = (1 + D1_scaling * D1) * (chi_MSE_coeff * chi_MSE + grad_MSE_coeff * grad_MSE)
|
||||||
|
where:
|
||||||
|
- chi_MSE is the mean squared error of the gravitational potential residual chi
|
||||||
|
- grad_MSE is the mean squared error of the gradient of the gravitational potential residual
|
||||||
|
- D1 is the first order linear growth factor, style[:,0]
|
||||||
|
"""
|
||||||
|
|
||||||
|
D1 = style[:,0] if style is not None else 0.
|
||||||
|
|
||||||
|
chi_MSE = torch.nn.MSELoss()(pred, target)
|
||||||
|
|
||||||
|
if self.grad_MSE_coeff <= 0:
|
||||||
|
return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE)
|
||||||
|
|
||||||
|
pred_grads = torch.gradient(pred, dim=[2,3,4]) # Assuming pred is a 5D tensor (batch, channel, depth, height, width)
|
||||||
|
target_grads = torch.gradient(target, dim=[2,3,4])
|
||||||
|
|
||||||
|
grad_MSE = torch.nn.MSELoss()(torch.stack(pred_grads, dim=2), torch.stack(target_grads, dim=2))
|
||||||
|
|
||||||
|
return (1 + self.D1_scaling * D1) * (self.chi_MSE_coeff * chi_MSE + self.grad_MSE_coeff * grad_MSE)
|
|
@ -14,7 +14,8 @@ def train_model(model,
|
||||||
save_model_path=None,
|
save_model_path=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
target_crop:int = None,
|
target_crop:int = None,
|
||||||
epoch_start:int = 0):
|
epoch_start:int = 0,
|
||||||
|
loss_fn=None):
|
||||||
"""
|
"""
|
||||||
Train a model with the given dataloader and optimizer.
|
Train a model with the given dataloader and optimizer.
|
||||||
|
|
||||||
|
@ -38,7 +39,9 @@ def train_model(model,
|
||||||
if scheduler is None:
|
if scheduler is None:
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
loss_fn = torch.nn.MSELoss()
|
if loss_fn is None:
|
||||||
|
from .losses_gravpot import LossBase
|
||||||
|
loss_fn = LossBase()
|
||||||
train_loss_log = []
|
train_loss_log = []
|
||||||
val_loss_log = []
|
val_loss_log = []
|
||||||
|
|
||||||
|
@ -70,7 +73,7 @@ def train_model(model,
|
||||||
|
|
||||||
if target_crop:
|
if target_crop:
|
||||||
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target, style=style)
|
||||||
forward_time += time.time() - t1
|
forward_time += time.time() - t1
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
|
@ -144,7 +147,7 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
|
||||||
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target, style=style)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense
|
styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense
|
||||||
progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
|
progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue