resume training
parent 0cbd7fcc46
commit 8058d81f26
1 changed file with 26 additions and 3 deletions
@@ -13,7 +13,8 @@ def train_model(model,
                 print_timers=False,
                 save_model_path=None,
                 scheduler=None,
-                target_crop:int = None):
+                target_crop:int = None,
+                epoch_start:int = 0):
     """
     Train a model with the given dataloader and optimizer.
 
@@ -33,7 +34,7 @@ def train_model(model,
     - val_loss_log: List of validation losses for each epoch."""
 
     if optimizer is None:
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     if scheduler is None:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
     model.to(device)
@@ -41,7 +42,7 @@ def train_model(model,
     train_loss_log = []
     val_loss_log = []
 
-    for epoch in range(num_epochs):
+    for epoch in range(epoch_start,num_epochs):
         model.train()
         progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
         io_time = 0.0
@@ -158,6 +159,28 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
     return losses.mean(), bin_means, bins
 
 
+def resume_training(train_loss_log, val_loss_log, **kwargs):
+    """
+    Resume training from the last epoch, updating the training and validation loss logs.
+
+    Parameters:
+    - train_loss_log: List of training losses from previous epochs.
+    - val_loss_log: List of validation losses from previous epochs.
+    - kwargs: Additional parameters to pass to the training function.
+
+    Returns:
+    - Updated train_loss_log and val_loss_log."""
+
+    if "epoch_start" not in kwargs:
+        kwargs["epoch_start"] = len(val_loss_log)
+
+    train_loss_log2, val_loss_log2 = train_model(**kwargs)
+    train_loss_log.extend(train_loss_log2)
+    val_loss_log.extend(val_loss_log2)
+
+    return train_loss_log, val_loss_log
+
+
 def train_models(models,
                  dataloader,
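A minimal usage sketch of the new resume_training helper. Assumptions: train_model returns the pair (train_loss_log, val_loss_log), and model, dataloader, and num_epochs are keyword arguments of train_model as the hunks above suggest; any other required arguments of train_model (e.g. a loss function) are elided here.

# Sketch only: train for the first 10 of 30 planned epochs.
# model and dataloader are assumed to be built elsewhere.
train_loss_log, val_loss_log = train_model(model=model,
                                           dataloader=dataloader,
                                           num_epochs=10)

# Later, pick up where the loss logs left off. Because epoch_start is not
# passed explicitly, resume_training sets it to len(val_loss_log) == 10, so
# the loop in train_model runs over epochs 10..29 and the extended logs
# cover all 30 epochs.
train_loss_log, val_loss_log = resume_training(train_loss_log, val_loss_log,
                                               model=model,
                                               dataloader=dataloader,
                                               num_epochs=30)

Deriving epoch_start from len(val_loss_log) keeps the epoch counter, the tqdm description, and the two logs consistent across interruptions, since one validation loss is appended per completed epoch.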