import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import time
from ..prepare_data.prepare_gravpot_data import prepare_data


def _sync(device):
    """Wait for queued CUDA work on ``device`` to finish; no-op for non-CUDA devices.

    CUDA kernels are launched asynchronously, so a wall-clock timer around a
    forward/backward call only measures launch overhead unless we synchronize.
    """
    if torch.device(device).type == 'cuda':
        torch.cuda.synchronize(device)


def _print_bin_summary(style_bins, style_bins_means):
    """Print validation bin left edges and per-bin mean losses as aligned columns."""
    left_edges = list(style_bins[:-1])
    # Width must fit the widest number in *either* row; the literal 2 keeps
    # max() from failing when both rows are empty.
    bin_width = max(len(f"{v:.2f}") for v in [*left_edges, *style_bins_means, 2])
    bins_label = "Style Bins: "
    bins_str = bins_label + " | ".join(f"{b:>{bin_width}.2f}" for b in left_edges)
    # Pad the shorter label so both rows' columns line up.
    means_str = "Means: ".ljust(len(bins_label)) + " | ".join(f"{m:>{bin_width}.2f}" for m in style_bins_means)
    print(bins_str)
    print(means_str)


def _print_timings(epoch, io_time, forward_time, backward_time, validation_time):
    """Print the per-phase wall-clock breakdown (seconds and share of total) for one epoch."""
    phases = [
        ("I/O", io_time),
        ("Forward", forward_time),
        ("Backward", backward_time),
        ("Validation", validation_time),
    ]
    # Guard against a zero total (e.g. empty loaders) to avoid ZeroDivisionError.
    total_time = sum(t for _, t in phases) or 1e-12
    print(f"Epoch {epoch+1} Timings:")
    for name, t in phases:
        print(f" {name} time: {t:.3f} s\t | {100 * t / total_time:.2f}%")


def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
    """Train ``model`` with an MSE objective, validating after every epoch.

    Args:
        model: Module called as ``model(input, style)``.
        dataloader: Mapping with ``'train'`` and ``'val'`` iterables of raw batches.
            Each batch is passed through ``prepare_data`` and must then provide
            ``'input'``, ``'target'`` and ``'style'`` tensors.
        optimizer: Optimizer to use; defaults to Adam with ``lr=1e-3``.
        num_epochs: Number of passes over ``dataloader['train']``.
        device: Device the model and batch tensors are moved to.
        print_timers: If True, print a per-epoch breakdown of I/O, forward,
            backward and validation wall-clock time. On CUDA this synchronizes
            the device around each phase, which slows training slightly.

    Returns:
        List of ``(style[:, 0] as a NumPy array, batch loss)`` tuples, one per
        training batch across all epochs.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model.to(device)
    loss_fn = torch.nn.MSELoss()
    loss_log = []

    def sync():
        # Only pay for device synchronization when timings are reported.
        if print_timers:
            _sync(device)

    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
        io_time = 0.0
        forward_time = 0.0
        backward_time = 0.0
        prev_time = time.perf_counter()

        for batch in progress_bar:
            batch = prepare_data(batch)
            inputs = batch['input'].to(device)
            target = batch['target'].to(device)
            style = batch['style'].to(device)
            optimizer.zero_grad()

            # I/O time: loader fetch + prepare_data + host-to-device copies.
            sync()
            t1 = time.perf_counter()
            io_time += t1 - prev_time

            # Forward pass
            output = model(inputs, style)
            loss = loss_fn(output, target)
            sync()
            t2 = time.perf_counter()
            forward_time += t2 - t1

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            sync()
            backward_time += time.perf_counter() - t2

            batch_loss = loss.item()
            loss_log.append((style[:, 0].detach().cpu().numpy(), batch_loss))
            progress_bar.set_postfix(loss=batch_loss)
            prev_time = time.perf_counter()  # start of the next batch's I/O interval

        # End of epoch: validate the model.
        t3 = time.perf_counter()
        val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
        validation_time = time.perf_counter() - t3
        print(f"Validation Loss: {val_loss:.4f}")
        _print_bin_summary(style_bins, style_bins_means)

        if print_timers:
            _print_timings(epoch, io_time, forward_time, backward_time, validation_time)

    return loss_log


def validate(model, val_loader, loss_fn, device='cuda', num_bins=9):
    """Evaluate ``model`` on ``val_loader`` and bin the loss by the first style parameter.

    Args:
        model: Module called as ``model(input, style)``.
        val_loader: Iterable of raw batches, each passed through ``prepare_data``.
        loss_fn: Loss callable returning a scalar tensor. It is applied to the
            whole batch and to each single-sample slice of the batch.
        device: Device the batch tensors are moved to.
        num_bins: Number of equal-width bins spanning the observed ``style[:, 0]``
            range (the default 9 matches the previous 10 bin edges).

    Returns:
        ``(mean_loss, bin_means, bins)``:
        - ``mean_loss`` is the mean of the per-batch losses.
        - ``bin_means`` is a list of ``num_bins`` per-sample mean losses, NaN for
          empty bins.
        - ``bins`` is the array of ``num_bins + 1`` bin edges.
        If ``val_loader`` yields no batches, returns ``(nan, [], empty array)``.
    """
    model.eval()
    batch_losses = []
    sample_losses = []
    styles = []
    progress_bar = tqdm(val_loader, desc="Validation", unit='batch')
    with torch.no_grad():
        for batch in progress_bar:
            batch = prepare_data(batch)
            inputs = batch['input'].to(device)
            target = batch['target'].to(device)
            style = batch['style'].to(device)
            output = model(inputs, style)
            loss = loss_fn(output, target)
            batch_losses.append(loss.item())
            # One loss per sample so losses line up element-wise with the
            # per-sample styles below (a single batch loss cannot be binned).
            per_sample = torch.stack([loss_fn(o, t) for o, t in zip(output.split(1), target.split(1))])
            sample_losses.append(per_sample.cpu().numpy())
            styles.append(style[:, 0].cpu().numpy())
            progress_bar.set_postfix(loss=loss.item())

    if not styles:
        return float('nan'), [], np.array([])

    # Bin per-sample loss by style[:, 0].
    styles = np.concatenate(styles)
    sample_losses = np.concatenate(sample_losses)
    bins = np.linspace(styles.min(), styles.max(), num_bins + 1)
    # np.digitize maps values equal to the last edge (the maximum style) to
    # len(bins); clip so they land in the final bin instead of being dropped.
    digitized = np.clip(np.digitize(styles, bins), 1, num_bins)
    bin_means = [
        sample_losses[digitized == i].mean() if np.any(digitized == i) else float('nan')
        for i in range(1, num_bins + 1)
    ]
    return np.mean(batch_losses), bin_means, bins