many improvements

This commit is contained in:
Mayeul Aubin 2025-06-17 18:07:06 +02:00
parent c07ec8f8cf
commit 6c526d7115
4 changed files with 219 additions and 53 deletions

View file

@ -5,12 +5,39 @@ import torch
import time
from ..prepare_data.prepare_gravpot_data import prepare_data
def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
def train_model(model,
dataloader,
optimizer=None,
num_epochs=10,
device='cuda',
print_timers=False,
save_model_path=None,
scheduler=None):
"""
Train a model with the given dataloader and optimizer.
Parameters:
- model: The model to train.
- dataloader: A dictionary with 'train' and 'val' DataLoader objects.
- optimizer: The optimizer to use for training (default is Adam with lr=1e-3).
- num_epochs: Number of epochs to train the model (default is 10).
- device: Device to run the model on (default is 'cuda').
- print_timers: If True, print timing information for each epoch (default is False).
- save_model_path: If provided, the model will be saved to this path after each epoch.
- scheduler: Learning rate scheduler (optional).
Returns:
- train_loss_log: List of training losses for each batch.
- val_loss_log: List of validation losses for each epoch."""
if optimizer is None:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
if scheduler is None:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//4)
model.to(device)
loss_fn = torch.nn.MSELoss()
loss_log = []
train_loss_log = []
val_loss_log = []
for epoch in range(num_epochs):
model.train()
@ -20,7 +47,8 @@ def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda',
backward_time = 0.0
validation_time = 0.0
prev_time = time.time()
epoch_start_time = time.time()
prev_time = epoch_start_time # For I/O timing
for batch in progress_bar:
# I/O timer: time since last batch processed
t0 = time.time()
@ -45,32 +73,52 @@ def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda',
optimizer.step()
backward_time += time.time() - t2
loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
progress_bar.set_postfix(loss=loss.item())
train_loss_log.append(loss.item())
progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
prev_time = time.time() # End of loop, for next I/O timing
# End of epoch, validate the model
t3 = time.time()
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
val_loss_log.append(val_loss)
validation_time += time.time() - t3
print(f"Validation Loss: {val_loss:.4f}")
# Prepare new samples for the next epoch
dataloader['train'].dataset.on_epoch_end()
dataloader['val'].dataset.on_epoch_end()
if save_model_path is not None:
torch.save(model.state_dict(), save_model_path+ f"_epoch_{epoch+1}.pth")
torch.save(dict(train_loss_log=train_loss_log,
val_loss_log=val_loss_log,
style_bins_means=style_bins_means,
style_bins=style_bins),
save_model_path + f"_epoch_{epoch+1}_stats.pth")
if scheduler is not None:
scheduler.step(val_loss)
print()
print(f"================ Epoch {epoch+1} Summary ================")
print(f"Validation Loss: {val_loss:2.6f}")
bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]]) # +[2] to avoid empty
bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]])
means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means])
print(bins_str)
print(means_str)
print()
if print_timers:
total_time = io_time + forward_time + backward_time + validation_time
print(f"Epoch {epoch+1} Timings:")
print(f" I/O time: {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
print(f" Forward time: {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
print(f" Backward time: {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
print(f" Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
total_time = time.time() - epoch_start_time
print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
print(f" I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
print()
return loss_log
return train_loss_log, val_loss_log
def validate(model, val_loader, loss_fn, device='cuda'):
@ -89,11 +137,11 @@ def validate(model, val_loader, loss_fn, device='cuda'):
output = model(input, style)
loss = loss_fn(output, target)
losses.append(loss.item())
styles.append(style[:, 0].cpu().numpy())
progress_bar.set_postfix(loss=loss.item())
styles.append(style[:, 0].cpu().numpy().mean()) # BEWARE: if batch size > 1, this will average the styles and make no sense
progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
# Bin loss by style[0]
styles = np.concatenate(styles)
styles = np.array(styles)
losses = np.array(losses)
bins = np.linspace(styles.min(), styles.max(), 10)
digitized = np.digitize(styles, bins)