import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm import torch import time from ..prepare_data.prepare_gravpot_data import prepare_data def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False, save_model_path=None, scheduler=None, target_crop:int = None, epoch_start:int = 0): """ Train a model with the given dataloader and optimizer. Parameters: - model: The model to train. - dataloader: A dictionary with 'train' and 'val' DataLoader objects. - optimizer: The optimizer to use for training (default is Adam with lr=1e-3). - num_epochs: Number of epochs to train the model (default is 10). - device: Device to run the model on (default is 'cuda'). - print_timers: If True, print timing information for each epoch (default is False). - save_model_path: If provided, the model will be saved to this path after each epoch. - scheduler: Learning rate scheduler (optional). - target_crop: If provided, the target will be cropped by this amount from each side (default is None, no cropping). Returns: - train_loss_log: List of training losses for each batch. - val_loss_log: List of validation losses for each epoch.""" if optimizer is None: optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) if scheduler is None: scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5) model.to(device) loss_fn = torch.nn.MSELoss() train_loss_log = [] val_loss_log = [] for epoch in range(epoch_start,num_epochs): model.train() progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch') io_time = 0.0 forward_time = 0.0 backward_time = 0.0 validation_time = 0.0 epoch_start_time = time.time() prev_time = epoch_start_time # For I/O timing for batch in progress_bar: # I/O timer: time since last batch processed t0 = time.time() io_time += t0 - prev_time batch = prepare_data(batch) input = batch['input'].to(device) target = batch['target'].to(device) style = batch['style'].to(device) optimizer.zero_grad() # Forward pass t1 = time.time() output = model(input, style) if target_crop: target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] loss = loss_fn(output, target) forward_time += time.time() - t1 # Backward pass and optimization t2 = time.time() loss.backward() optimizer.step() backward_time += time.time() - t2 train_loss_log.append(loss.item()) progress_bar.set_postfix(loss=f"{loss.item():2.5f}") prev_time = time.time() # End of loop, for next I/O timing # End of epoch, validate the model t3 = time.time() val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop) val_loss_log.append(val_loss) validation_time += time.time() - t3 # Prepare new samples for the next epoch dataloader['train'].dataset.on_epoch_end() dataloader['val'].dataset.on_epoch_end() if save_model_path is not None: torch.save(model.state_dict(), save_model_path+ f"_epoch_{epoch+1}.pth") torch.save(dict(train_loss_log=train_loss_log, val_loss_log=val_loss_log, style_bins_means=style_bins_means, style_bins=style_bins), save_model_path + f"_epoch_{epoch+1}_stats.pth") if scheduler is not None: scheduler.step(val_loss) print() print(f"================ Epoch {epoch+1} Summary ================") print(f"Validation Loss: {val_loss:2.6f}") bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]]) # +[2] to avoid empty bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]]) means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means]) print(bins_str) print(means_str) print() if print_timers: total_time = time.time() - epoch_start_time print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s") print(f" Sync time + I/O: {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%") print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%") print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%") print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%") print() return train_loss_log, val_loss_log def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None): model.eval() losses = [] styles = [] progress_bar = tqdm(val_loader, desc="Validation", unit='batch') with torch.no_grad(): for batch in progress_bar: batch = prepare_data(batch) input = batch['input'].to(device) target = batch['target'].to(device) style = batch['style'].to(device) if target_crop: target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] output = model(input, style) loss = loss_fn(output, target) losses.append(loss.item()) styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense progress_bar.set_postfix(loss=f"{loss.item():2.5f}") # Bin loss by style[0] styles = np.array(styles) losses = np.array(losses) bins = np.linspace(styles.min(), styles.max(), 10) digitized = np.digitize(styles, bins) bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))] return losses.mean(), bin_means, bins def resume_training(train_loss_log, val_loss_log, **kwargs): """ Resume training from the last epoch, updating the training and validation loss logs. Parameters: - train_loss_log: List of training losses from previous epochs. - val_loss_log: List of validation losses from previous epochs. - kwargs: Additional parameters to pass to the training function. Returns: - Updated train_loss_log and val_loss_log.""" if "epoch_start" not in kwargs: kwargs["epoch_start"] = len(val_loss_log) train_loss_log2, val_loss_log2 = train_model(**kwargs) train_loss_log.extend(train_loss_log2) val_loss_log.extend(val_loss_log2) return train_loss_log, val_loss_log def train_models(models, dataloader, optimizers=None, num_epochs=10, device='cuda', print_timers=False, save_model_paths=None, schedulers=None): """ Train multiple models with their respective dataloaders and optimizers. This is useful since the main bottelneck is I/O, so training multiple models on the same data loaded. Parameters: - models: List of models to train. - dataloader: Dictionnary with 'train' and 'val' DataLoader objects. - optimizers: List of optimizers for each model (default is Adam with lr=1e-3). - num_epochs: Number of epochs to train the models (default is 10). - device: Device to run the models on (default is 'cuda'). - print_timers: If True, print timing information for each epoch (default is False). - save_model_paths: List of paths to save the models after each epoch. - schedulers: List of learning rate schedulers for each model (optional). Returns: - train_loss_logs: List of training losses for each model. - val_loss_logs: List of validation losses for each model.""" if optimizers is None: optimizers = [torch.optim.Adam(model.parameters(), lr=1e-4) for model in models] if schedulers is None: schedulers = [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5) for optimizer in optimizers] models = [model.to(device) for model in models] loss_fns = [torch.nn.MSELoss() for _ in models] train_loss_logs = [[] for _ in models] val_loss_logs = [[] for _ in models] if save_model_paths is None: save_model_paths = [None] * len(models) if len(save_model_paths) != len(models) or len(optimizers) != len(models) or len(schedulers) != len(models): raise ValueError("Length of save_model_paths, optimizers, and schedulers must match the number of models.") print(f"Starting training for {len(models)} models...") for epoch in range(num_epochs): for model in models: model.train() progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch') io_time = 0.0 forward_time = 0.0 backward_time = 0.0 validation_time = 0.0 epoch_start_time = time.time() prev_time = epoch_start_time for batch in progress_bar: # I/O timer: time since last batch processed t0 = time.time() io_time += t0 - prev_time batch = prepare_data(batch) input = batch['input'].to(device) target = batch['target'].to(device) style = batch['style'].to(device) # Loop on models for training for i, model in enumerate(models): optimizers[i].zero_grad() # Forward pass t1 = time.time() output = model(input, style) loss = loss_fns[i](output, target) forward_time += time.time() - t1 # Backward pass and optimization t2 = time.time() loss.backward() optimizers[i].step() backward_time += time.time() - t2 train_loss_logs[i].append(loss.item()) progress_bar.set_postfix(loss=f"{loss.item():2.5f}") prev_time = time.time() # End of epoch, validate the models t3 = time.time() for i, model in enumerate(models): val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fns[i], device) val_loss_logs[i].append(val_loss) if save_model_paths[i] is not None: torch.save(model.state_dict(), save_model_paths[i] + f"_epoch_{epoch+1}.pth") torch.save(dict(train_loss_log=train_loss_logs[i], val_loss_log=val_loss_logs[i], style_bins_means=style_bins_means, style_bins=style_bins), save_model_paths[i] + f"_epoch_{epoch+1}_stats.pth") if schedulers[i] is not None: schedulers[i].step(val_loss) validation_time += time.time() - t3 # Prepare new samples for the next epoch dataloader['train'].dataset.on_epoch_end() dataloader['val'].dataset.on_epoch_end() print() print(f"================ Epoch {epoch+1} Summary ================") for i, model in enumerate(models): print(f"Model {i+1} Validation Loss: {val_loss_logs[i][-1]:2.6f}") if print_timers: total_time = time.time() - epoch_start_time print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s") print(f" I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%") print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%") print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%") print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%") print() print("Training complete.") return train_loss_logs, val_loss_logs