expansion

This commit is contained in:
Mayeul Aubin 2025-06-06 13:52:17 +02:00
parent 24c2d546db
commit c07ec8f8cf
6 changed files with 138 additions and 1 deletions

View file

@ -0,0 +1,102 @@
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import time
from ..prepare_data.prepare_gravpot_data import prepare_data
def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
if optimizer is None:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.to(device)
loss_fn = torch.nn.MSELoss()
loss_log = []
for epoch in range(num_epochs):
model.train()
progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
io_time = 0.0
forward_time = 0.0
backward_time = 0.0
validation_time = 0.0
prev_time = time.time()
for batch in progress_bar:
# I/O timer: time since last batch processed
t0 = time.time()
io_time += t0 - prev_time
batch = prepare_data(batch)
input = batch['input'].to(device)
target = batch['target'].to(device)
style = batch['style'].to(device)
optimizer.zero_grad()
# Forward pass
t1 = time.time()
output = model(input, style)
loss = loss_fn(output, target)
forward_time += time.time() - t1
# Backward pass and optimization
t2 = time.time()
loss.backward()
optimizer.step()
backward_time += time.time() - t2
loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
progress_bar.set_postfix(loss=loss.item())
prev_time = time.time() # End of loop, for next I/O timing
# End of epoch, validate the model
t3 = time.time()
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
validation_time += time.time() - t3
print(f"Validation Loss: {val_loss:.4f}")
bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]]) # +[2] to avoid empty
bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
print(bins_str)
print(means_str)
if print_timers:
total_time = io_time + forward_time + backward_time + validation_time
print(f"Epoch {epoch+1} Timings:")
print(f" I/O time: {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
print(f" Forward time: {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
print(f" Backward time: {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
print(f" Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
return loss_log
def validate(model, val_loader, loss_fn, device='cuda'):
model.eval()
losses = []
styles = []
progress_bar = tqdm(val_loader, desc="Validation", unit='batch')
with torch.no_grad():
for batch in progress_bar:
batch = prepare_data(batch)
input = batch['input'].to(device)
target = batch['target'].to(device)
style = batch['style'].to(device)
output = model(input, style)
loss = loss_fn(output, target)
losses.append(loss.item())
styles.append(style[:, 0].cpu().numpy())
progress_bar.set_postfix(loss=loss.item())
# Bin loss by style[0]
styles = np.concatenate(styles)
losses = np.array(losses)
bins = np.linspace(styles.min(), styles.max(), 10)
digitized = np.digitize(styles, bins)
bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]
return losses.mean(), bin_means, bins