resume training
parent 0cbd7fcc46
commit 8058d81f26
1 changed file with 26 additions and 3 deletions
@@ -13,7 +13,8 @@ def train_model(model,
                 print_timers=False,
                 save_model_path=None,
                 scheduler=None,
-                target_crop:int = None):
+                target_crop:int = None,
+                epoch_start:int = 0):
     """
     Train a model with the given dataloader and optimizer.
 
@@ -33,7 +34,7 @@ def train_model(model,
     - val_loss_log: List of validation losses for each epoch."""
 
     if optimizer is None:
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     if scheduler is None:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
     model.to(device)
@@ -41,7 +42,7 @@ def train_model(model,
     train_loss_log = []
     val_loss_log = []
 
-    for epoch in range(num_epochs):
+    for epoch in range(epoch_start,num_epochs):
         model.train()
         progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
         io_time = 0.0
@@ -158,6 +159,28 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
     return losses.mean(), bin_means, bins
 
 
+def resume_training(train_loss_log, val_loss_log, **kwargs):
+    """
+    Resume training from the last epoch, updating the training and validation loss logs.
+
+    Parameters:
+    - train_loss_log: List of training losses from previous epochs.
+    - val_loss_log: List of validation losses from previous epochs.
+    - kwargs: Additional parameters to pass to the training function.
+
+    Returns:
+    - Updated train_loss_log and val_loss_log."""
+
+    if "epoch_start" not in kwargs:
+        kwargs["epoch_start"] = len(val_loss_log)
+
+    train_loss_log2, val_loss_log2 = train_model(**kwargs)
+    train_loss_log.extend(train_loss_log2)
+    val_loss_log.extend(val_loss_log2)
+
+    return train_loss_log, val_loss_log
+
+
 def train_models(models,
                  dataloader,
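
For illustration, a minimal sketch of how the new resume_training helper would be called after an interrupted or extended run. The names my_model and loaders are hypothetical placeholders, and the import assumes the functions live in a module named train; neither is part of this commit:

    # Hypothetical usage sketch; module and variable names are assumed.
    from train import train_model, resume_training

    # Initial run: trains epochs 0..9 and returns per-epoch loss logs.
    train_loss_log, val_loss_log = train_model(
        model=my_model, dataloader=loaders, num_epochs=10)

    # Resume up to epoch 20. epoch_start defaults to len(val_loss_log) == 10,
    # so the loop runs epochs 10..19 and the logs grow to cover all 20 epochs.
    train_loss_log, val_loss_log = resume_training(
        train_loss_log, val_loss_log,
        model=my_model, dataloader=loaders, num_epochs=20)

Note that only the loss logs and the already-trained model instance carry over across the resume: unless an optimizer or scheduler is passed through kwargs, train_model appears to rebuild a fresh Adam and ReduceLROnPlateau, so their internal state from the first run is discarded.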