diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py
index 93d9677..6bd84e0 100644
--- a/sCOCA_ML/train/train_gravpot.py
+++ b/sCOCA_ML/train/train_gravpot.py
@@ -13,7 +13,8 @@ def train_model(model,
                 print_timers=False,
                 save_model_path=None,
                 scheduler=None,
-                target_crop:int = None):
+                target_crop:int = None,
+                epoch_start:int = 0):
     """
     Train a model with the given dataloader and optimizer.
 
@@ -33,7 +34,7 @@ def train_model(model,
     - val_loss_log: List of validation losses for each epoch."""
 
     if optimizer is None:
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
     if scheduler is None:
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
     model.to(device)
@@ -41,7 +42,7 @@ def train_model(model,
     train_loss_log = []
     val_loss_log = []
 
-    for epoch in range(num_epochs):
+    for epoch in range(epoch_start, num_epochs):
         model.train()
         progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
         io_time = 0.0
@@ -158,6 +159,28 @@ def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
     return losses.mean(), bin_means, bins
 
 
+def resume_training(train_loss_log, val_loss_log, **kwargs):
+    """
+    Resume training from the last epoch, updating the training and validation loss logs.
+
+    Parameters:
+    - train_loss_log: List of training losses from previous epochs.
+    - val_loss_log: List of validation losses from previous epochs.
+    - kwargs: Additional parameters to pass to the training function.
+
+    Returns:
+    - Updated train_loss_log and val_loss_log."""
+
+    if "epoch_start" not in kwargs:
+        kwargs["epoch_start"] = len(val_loss_log)
+
+    train_loss_log2, val_loss_log2 = train_model(**kwargs)
+    train_loss_log.extend(train_loss_log2)
+    val_loss_log.extend(val_loss_log2)
+
+    return train_loss_log, val_loss_log
+
+
 def train_models(models,
                  dataloader,
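
Usage note: a minimal sketch (not part of the patch) of how the new resume_training helper composes with train_model, e.g. extending a completed 10-epoch run to 20 epochs. Only model, dataloader, num_epochs, and epoch_start are taken from this diff; any other keyword arguments would have to match train_model's actual signature.

    # First run: train_model returns the per-epoch loss logs.
    train_loss_log, val_loss_log = train_model(model=model,
                                               dataloader=dataloader,
                                               num_epochs=10)

    # resume_training infers epoch_start = len(val_loss_log) == 10 (one
    # validation loss is logged per completed epoch), so train_model loops
    # over range(10, 20). Both logs are extended in place and returned.
    train_loss_log, val_loss_log = resume_training(train_loss_log, val_loss_log,
                                                   model=model,
                                                   dataloader=dataloader,
                                                   num_epochs=20)

One caveat visible in the diff: unless an optimizer/scheduler is passed through kwargs, train_model builds fresh ones on each call, so optimizer state from the first run is not carried over when resuming.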