resume training

This commit is contained in:
Mayeul Aubin 2025-06-30 14:30:21 +02:00
parent 0cbd7fcc46
commit 8058d81f26

View file

@ -13,7 +13,8 @@ def train_model(model,
print_timers=False,
save_model_path=None,
scheduler=None,
target_crop:int = None):
target_crop:int = None,
epoch_start:int = 0):
"""
Train a model with the given dataloader and optimizer.
@ -33,7 +34,7 @@ def train_model(model,
- val_loss_log: List of validation losses for each epoch."""
if optimizer is None:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
if scheduler is None:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
model.to(device)
@ -41,7 +42,7 @@ def train_model(model,
train_loss_log = []
val_loss_log = []
for epoch in range(num_epochs):
for epoch in range(epoch_start,num_epochs):
model.train()
progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
io_time = 0.0
@ -158,6 +159,28 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
return losses.mean(), bin_means, bins
def resume_training(train_loss_log, val_loss_log, **kwargs):
    """Continue an interrupted training run and merge the new loss histories.

    Parameters:
    - train_loss_log: List of training losses from previous epochs.
    - val_loss_log: List of validation losses from previous epochs.
    - kwargs: Additional parameters forwarded verbatim to ``train_model``.

    Returns:
    - The train_loss_log and val_loss_log lists, extended with the losses
      produced by the resumed epochs.
    """
    # Default the starting epoch to the number of epochs already completed,
    # as measured by the length of the validation-loss history.
    kwargs.setdefault("epoch_start", len(val_loss_log))
    new_train_losses, new_val_losses = train_model(**kwargs)
    train_loss_log += new_train_losses
    val_loss_log += new_val_losses
    return train_loss_log, val_loss_log
def train_models(models,
dataloader,