import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import time
from ..prepare_data.prepare_gravpot_data import prepare_data


def train_model(model,
                dataloader,
                optimizer=None,
                num_epochs=10,
                device='cuda',
                print_timers=False,
                save_model_path=None,
                scheduler=None,
                target_crop: int = None,
                epoch_start: int = 0):
    """
    Train a model with the given dataloader and optimizer.

    Parameters:
    - model: The model to train.
    - dataloader: A dictionary with 'train' and 'val' DataLoader objects.
    - optimizer: The optimizer to use for training (default is Adam with lr=1e-3).
    - num_epochs: Number of epochs to train the model (default is 10).
    - device: Device to run the model on (default is 'cuda').
    - print_timers: If True, print timing information for each epoch (default is False).
    - save_model_path: If provided, the model will be saved to this path after each epoch.
    - scheduler: Learning rate scheduler (optional, default is ReduceLROnPlateau on the validation loss).
    - target_crop: If provided, the target will be cropped by this amount from each side of its three spatial dimensions (default is None, no cropping).
    - epoch_start: Epoch index to start from, useful when resuming training (default is 0).

    Returns:
    - train_loss_log: List of training losses for each batch.
    - val_loss_log: List of validation losses for each epoch."""

    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
    model.to(device)
    loss_fn = torch.nn.MSELoss()
    train_loss_log = []
    val_loss_log = []

    for epoch in range(epoch_start, num_epochs):
        model.train()
        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
        io_time = 0.0
        forward_time = 0.0
        backward_time = 0.0
        validation_time = 0.0

        epoch_start_time = time.time()
        prev_time = epoch_start_time  # For I/O timing
        for batch in progress_bar:
            # I/O timer: time since last batch processed
            t0 = time.time()
            io_time += t0 - prev_time

            batch = prepare_data(batch)
            input = batch['input'].to(device)
            target = batch['target'].to(device)
            style = batch['style'].to(device)

            optimizer.zero_grad()

            # Forward pass
            t1 = time.time()
            output = model(input, style)

            if target_crop:
                target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
            loss = loss_fn(output, target)
            forward_time += time.time() - t1

            # Backward pass and optimization
            t2 = time.time()
            loss.backward()
            optimizer.step()
            backward_time += time.time() - t2

            train_loss_log.append(loss.item())
            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")

            prev_time = time.time()  # End of loop, for next I/O timing

        # End of epoch, validate the model
        t3 = time.time()
        val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop)
        val_loss_log.append(val_loss)
        validation_time += time.time() - t3

        # Prepare new samples for the next epoch
        dataloader['train'].dataset.on_epoch_end()
        dataloader['val'].dataset.on_epoch_end()

        if save_model_path is not None:
            torch.save(model.state_dict(), save_model_path + f"_epoch_{epoch+1}.pth")
            torch.save(dict(train_loss_log=train_loss_log,
                            val_loss_log=val_loss_log,
                            style_bins_means=style_bins_means,
                            style_bins=style_bins),
                       save_model_path + f"_epoch_{epoch+1}_stats.pth")

        if scheduler is not None:
            scheduler.step(val_loss)

        print()
        print(f"================ Epoch {epoch+1} Summary ================")
        print(f"Validation Loss: {val_loss:2.6f}")
        bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]])  # +[2] to avoid empty
        bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]])
        means_str = "Means:      " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means])
        print(bins_str)
        print(means_str)
        print()

        if print_timers:
            total_time = time.time() - epoch_start_time
            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
            print(f" Sync time + I/O: {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
            print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
            print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
            print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
            print()

    return train_loss_log, val_loss_log
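

# Usage sketch (hypothetical names): `GravPotModel`, `train_loader` and `val_loader`
# are placeholders for the project's actual model class and DataLoaders; only the
# call signature of `train_model` and the use of the returned logs are illustrated.
#
#   model = GravPotModel()
#   dataloader = {'train': train_loader, 'val': val_loader}
#   train_loss_log, val_loss_log = train_model(model,
#                                              dataloader,
#                                              num_epochs=20,
#                                              device='cuda',
#                                              print_timers=True,
#                                              save_model_path='checkpoints/gravpot',
#                                              target_crop=4)
#   plt.plot(train_loss_log)   # per-batch training loss
#   plt.figure()
#   plt.plot(val_loss_log)     # per-epoch validation loss
#   plt.show()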


def validate(model, val_loader, loss_fn, device='cuda', target_crop: int = None):
    """
    Evaluate the model on the validation set and bin the per-batch losses by the first style parameter.

    Returns the mean validation loss, the per-bin mean losses, and the bin edges.
    """
    model.eval()
    losses = []
    styles = []
    progress_bar = tqdm(val_loader, desc="Validation", unit='batch')

    with torch.no_grad():
        for batch in progress_bar:
            batch = prepare_data(batch)
            input = batch['input'].to(device)
            target = batch['target'].to(device)
            style = batch['style'].to(device)

            if target_crop:
                target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]

            output = model(input, style)
            loss = loss_fn(output, target)
            losses.append(loss.item())
            styles.append(style[:, 0].cpu().numpy().mean())  # BEWARE: if batch size > 1, this will average the styles and make no sense
            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")

    # Bin loss by style[0]
    styles = np.array(styles)
    losses = np.array(losses)
    bins = np.linspace(styles.min(), styles.max(), 10)
    digitized = np.digitize(styles, bins)
    bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]

    return losses.mean(), bin_means, bins
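

# Illustration of the style binning used in `validate` above: per-batch losses are
# grouped by the mean of the first style parameter using np.digitize. The numbers
# below are toy values, not project data; they only show how `bin_means` and `bins`
# relate to the inputs. This helper is not used by the training code.
def _style_binning_demo():
    styles = np.array([0.10, 0.35, 0.40, 0.80, 0.90])   # one style value per validation batch
    losses = np.array([1.00, 0.70, 0.50, 0.20, 0.10])   # matching per-batch losses
    bins = np.linspace(styles.min(), styles.max(), 10)  # 10 edges -> 9 bins
    digitized = np.digitize(styles, bins)                # 1-based bin index of each batch
    # Note: the batch whose style equals the last edge gets index len(bins) and is
    # therefore not counted by the range(1, len(bins)) loop, exactly as in `validate`.
    bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0
                 for i in range(1, len(bins))]
    return bin_means, bins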


def resume_training(train_loss_log, val_loss_log, **kwargs):
    """
    Resume training from the last epoch, updating the training and validation loss logs.

    Parameters:
    - train_loss_log: List of training losses from previous epochs.
    - val_loss_log: List of validation losses from previous epochs.
    - kwargs: Additional parameters to pass to the training function.

    Returns:
    - Updated train_loss_log and val_loss_log."""

    if "epoch_start" not in kwargs:
        kwargs["epoch_start"] = len(val_loss_log)

    train_loss_log2, val_loss_log2 = train_model(**kwargs)
    train_loss_log.extend(train_loss_log2)
    val_loss_log.extend(val_loss_log2)

    return train_loss_log, val_loss_log
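

# Resuming sketch (hypothetical checkpoint path and objects): reload the weights saved
# by `train_model`, then continue training with the previously returned loss logs.
# `epoch_start` defaults to len(val_loss_log), i.e. the number of epochs already done.
#
#   model = GravPotModel()
#   model.load_state_dict(torch.load('checkpoints/gravpot_epoch_10.pth'))
#   train_loss_log, val_loss_log = resume_training(train_loss_log, val_loss_log,
#                                                  model=model,
#                                                  dataloader={'train': train_loader, 'val': val_loader},
#                                                  num_epochs=20)   # runs epochs 11 to 20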


def train_models(models,
                 dataloader,
                 optimizers=None,
                 num_epochs=10,
                 device='cuda',
                 print_timers=False,
                 save_model_paths=None,
                 schedulers=None):
    """
    Train multiple models with their respective optimizers on a shared dataloader.
    This is useful since the main bottleneck is I/O, so several models can be trained on the same loaded data.

    Parameters:
    - models: List of models to train.
    - dataloader: A dictionary with 'train' and 'val' DataLoader objects.
    - optimizers: List of optimizers, one per model (default is Adam with lr=1e-4 for each).
    - num_epochs: Number of epochs to train the models (default is 10).
    - device: Device to run the models on (default is 'cuda').
    - print_timers: If True, print timing information for each epoch (default is False).
    - save_model_paths: List of paths to save the models after each epoch.
    - schedulers: List of learning rate schedulers, one per model (optional).

    Returns:
    - train_loss_logs: List of training loss logs, one per model.
    - val_loss_logs: List of validation loss logs, one per model."""

    if optimizers is None:
        optimizers = [torch.optim.Adam(model.parameters(), lr=1e-4) for model in models]

    if schedulers is None:
        schedulers = [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
                      for optimizer in optimizers]

    models = [model.to(device) for model in models]

    loss_fns = [torch.nn.MSELoss() for _ in models]

    train_loss_logs = [[] for _ in models]
    val_loss_logs = [[] for _ in models]

    if save_model_paths is None:
        save_model_paths = [None] * len(models)

    if len(save_model_paths) != len(models) or len(optimizers) != len(models) or len(schedulers) != len(models):
        raise ValueError("Length of save_model_paths, optimizers, and schedulers must match the number of models.")

    print(f"Starting training for {len(models)} models...")

    for epoch in range(num_epochs):

        for model in models:
            model.train()

        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
        io_time = 0.0
        forward_time = 0.0
        backward_time = 0.0
        validation_time = 0.0
        epoch_start_time = time.time()
        prev_time = epoch_start_time

        for batch in progress_bar:
            # I/O timer: time since last batch processed
            t0 = time.time()
            io_time += t0 - prev_time

            batch = prepare_data(batch)
            input = batch['input'].to(device)
            target = batch['target'].to(device)
            style = batch['style'].to(device)

            # Loop on models for training
            for i, model in enumerate(models):
                optimizers[i].zero_grad()

                # Forward pass
                t1 = time.time()
                output = model(input, style)
                loss = loss_fns[i](output, target)
                forward_time += time.time() - t1

                # Backward pass and optimization
                t2 = time.time()
                loss.backward()
                optimizers[i].step()
                backward_time += time.time() - t2

                train_loss_logs[i].append(loss.item())
                progress_bar.set_postfix(loss=f"{loss.item():2.5f}")

            prev_time = time.time()

        # End of epoch, validate the models
        t3 = time.time()
        for i, model in enumerate(models):
            val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fns[i], device)
            val_loss_logs[i].append(val_loss)

            if save_model_paths[i] is not None:
                torch.save(model.state_dict(), save_model_paths[i] + f"_epoch_{epoch+1}.pth")
                torch.save(dict(train_loss_log=train_loss_logs[i],
                                val_loss_log=val_loss_logs[i],
                                style_bins_means=style_bins_means,
                                style_bins=style_bins),
                           save_model_paths[i] + f"_epoch_{epoch+1}_stats.pth")

            if schedulers[i] is not None:
                schedulers[i].step(val_loss)
        validation_time += time.time() - t3

        # Prepare new samples for the next epoch
        dataloader['train'].dataset.on_epoch_end()
        dataloader['val'].dataset.on_epoch_end()

        print()
        print(f"================ Epoch {epoch+1} Summary ================")
        for i, model in enumerate(models):
            print(f"Model {i+1} Validation Loss: {val_loss_logs[i][-1]:2.6f}")

        if print_timers:
            total_time = time.time() - epoch_start_time
            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
            print(f" I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
            print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
            print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
            print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
            print()

    print("Training complete.")
    return train_loss_logs, val_loss_logs
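

# Usage sketch (hypothetical names): train several models on the same batches to
# amortize the I/O cost; the model classes, loaders and paths below are placeholders.
#
#   models = [SmallGravPotModel(), LargeGravPotModel()]
#   train_loss_logs, val_loss_logs = train_models(models,
#                                                 dataloader={'train': train_loader, 'val': val_loader},
#                                                 num_epochs=20,
#                                                 save_model_paths=['checkpoints/small', 'checkpoints/large'])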