added custom loss
This commit is contained in:
parent
8058d81f26
commit
70fcad44f5
2 changed files with 51 additions and 4 deletions
|
@ -14,7 +14,8 @@ def train_model(model,
|
|||
save_model_path=None,
|
||||
scheduler=None,
|
||||
target_crop:int = None,
|
||||
epoch_start:int = 0):
|
||||
epoch_start:int = 0,
|
||||
loss_fn=None):
|
||||
"""
|
||||
Train a model with the given dataloader and optimizer.
|
||||
|
||||
|
@ -38,7 +39,9 @@ def train_model(model,
|
|||
if scheduler is None:
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
|
||||
model.to(device)
|
||||
loss_fn = torch.nn.MSELoss()
|
||||
if loss_fn is None:
|
||||
from .losses_gravpot import LossBase
|
||||
loss_fn = LossBase()
|
||||
train_loss_log = []
|
||||
val_loss_log = []
|
||||
|
||||
|
@ -70,7 +73,7 @@ def train_model(model,
|
|||
|
||||
if target_crop:
|
||||
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||
loss = loss_fn(output, target)
|
||||
loss = loss_fn(output, target, style=style)
|
||||
forward_time += time.time() - t1
|
||||
|
||||
# Backward pass and optimization
|
||||
|
@ -144,7 +147,7 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
|
|||
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||
|
||||
output = model(input, style)
|
||||
loss = loss_fn(output, target)
|
||||
loss = loss_fn(output, target, style=style)
|
||||
losses.append(loss.item())
|
||||
styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense
|
||||
progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue