diff --git a/sCOCA_ML/dataset/momenta_dataset.py b/sCOCA_ML/dataset/momenta_dataset.py
index d8ecf1f..ffa062f 100644
--- a/sCOCA_ML/dataset/momenta_dataset.py
+++ b/sCOCA_ML/dataset/momenta_dataset.py
@@ -40,7 +40,7 @@ class MomentaDataset(Dataset):
                  initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
                  target_variable:str='gravpot',
                  style_files:str='cosmo_and_time_parameters',
-                 style_keys:list|None=["D1", "D2"],
+                 style_keys:list|None=["a"],
                  max_time:int=100):
        """
        Dataset for residual momenta data (COCA).
diff --git a/sCOCA_ML/prepare_data/prepare_momenta_data.py b/sCOCA_ML/prepare_data/prepare_momenta_data.py
new file mode 100644
index 0000000..c121d1a
--- /dev/null
+++ b/sCOCA_ML/prepare_data/prepare_momenta_data.py
@@ -0,0 +1,29 @@
+from .utils_functions import *
+
+def prepare_data(batch,
+                 scale_phi_ini:float = 1000.0,
+                 scale_delta_ini:float = 12.0,
+                 scale_target:float = 50.0,
+                 lin_threshold_target:float = 2.,
+                 ):
+
+    # delta_ini = batch['input'][:, [0], :, :, :]
+    # phi_ini = batch['input'][:, [1], :, :, :]
+    a = batch['style'][:, [0], None, None, None]
+    momenta = batch['target'][:, [0], :, :, :]
+
+
+    _input = batch['input']
+    _input[:, 0, :, :, :] /= scale_delta_ini
+    _input[:, 1, :, :, :] /= scale_phi_ini
+
+    _target = momenta / (1.14*a)**(2.31)
+
+    _style = batch['style']
+
+    return {
+        'input': _input,
+        'target': _target,
+        'style': _style
+    }
+
\ No newline at end of file
diff --git a/sCOCA_ML/train/train_momenta.py b/sCOCA_ML/train/train_momenta.py
new file mode 100644
index 0000000..555d6cb
--- /dev/null
+++ b/sCOCA_ML/train/train_momenta.py
@@ -0,0 +1,299 @@
+import numpy as np
+import matplotlib.pyplot as plt
+from tqdm import tqdm
+import torch
+import time
+from ..prepare_data.prepare_momenta_data import prepare_data
+
+def train_model(model,
+                dataloader,
+                optimizer=None,
+                num_epochs=10,
+                device='cuda',
+                print_timers=False,
+                save_model_path=None,
+                scheduler=None,
+                target_crop:int|None = None):
+    """
+    Train a model with the given dataloader and optimizer.
+
+    Parameters:
+    - model: The model to train.
+    - dataloader: A dictionary with 'train' and 'val' DataLoader objects.
+    - optimizer: The optimizer to use for training (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the model (default is 10).
+    - device: Device to run the model on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_path: If provided, the model will be saved to this path after each epoch.
+    - scheduler: Learning rate scheduler (optional).
+    - target_crop: If provided, the target will be cropped by this amount from each side (default is None, no cropping).
+
+    Returns:
+    - train_loss_log: List of training losses for each batch.
+    - val_loss_log: List of validation losses for each epoch.
+    """
+
+    if optimizer is None:
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    if scheduler is None:
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
+    model.to(device)
+
+    def loss_fn(output, target):
+        """MSE + MSE on the norm of the vector field (channels)"""
+        mse_loss = torch.nn.functional.mse_loss(output, target)
+        norm_output = torch.norm(output, dim=1, keepdim=True)
+        norm_target = torch.norm(target, dim=1, keepdim=True)
+        norm_loss = torch.nn.functional.mse_loss(norm_output, norm_target)
+        return mse_loss + norm_loss
+
+    train_loss_log = []
+    val_loss_log = []
+
+    for epoch in range(num_epochs):
+        model.train()
+        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
+        io_time = 0.0
+        forward_time = 0.0
+        backward_time = 0.0
+        validation_time = 0.0
+
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time  # For I/O timing
+        for batch in progress_bar:
+            # I/O timer: time since last batch processed
+            t0 = time.time()
+            io_time += t0 - prev_time
+
+            batch = prepare_data(batch)
+            input = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            optimizer.zero_grad()
+
+            # Forward pass
+            t1 = time.time()
+            output = model(input, style)
+
+            if target_crop:
+                target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
+            loss = loss_fn(output, target)
+            forward_time += time.time() - t1
+
+            # Backward pass and optimization
+            t2 = time.time()
+            loss.backward()
+            optimizer.step()
+            backward_time += time.time() - t2
+
+            train_loss_log.append(loss.item())
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
+
+            prev_time = time.time()  # End of loop, for next I/O timing
+
+        # End of epoch, validate the model
+        t3 = time.time()
+        val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop)
+        val_loss_log.append(val_loss)
+        validation_time += time.time() - t3
+
+        # Prepare new samples for the next epoch
+        dataloader['train'].dataset.on_epoch_end()
+        dataloader['val'].dataset.on_epoch_end()
+
+        if save_model_path is not None:
+            torch.save(model.state_dict(), save_model_path + f"_epoch_{epoch+1}.pth")
+            torch.save(dict(train_loss_log=train_loss_log,
+                            val_loss_log=val_loss_log,
+                            style_bins_means=style_bins_means,
+                            style_bins=style_bins),
+                       save_model_path + f"_epoch_{epoch+1}_stats.pth")
+
+        if scheduler is not None:
+            scheduler.step(val_loss)
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        print(f"Validation Loss: {val_loss:2.6f}")
+        bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]])  # +[2] to avoid empty
+        bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]])
+        means_str = "Means:      " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means])
+        print(bins_str)
+        print(means_str)
+        print()
+
+        if print_timers:
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"  Sync time + I/O:  {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"  Forward time:     {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"  Backward time:    {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"  Validation time:  {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+            print()
+
+    return train_loss_log, val_loss_log
+
+
+def validate(model, val_loader, loss_fn, device='cuda', target_crop:int|None = None):
+    model.eval()
+    losses = []
+    styles = []
+    progress_bar = tqdm(val_loader, desc="Validation", unit='batch')
+
+    with torch.no_grad():
+        for batch in progress_bar:
+            batch = prepare_data(batch)
+            input = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            if target_crop:
+                target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
+
+            output = model(input, style)
+            loss = loss_fn(output, target)
+            losses.append(loss.item())
+            styles.append(style[:, 0].cpu().numpy().mean())  # BEWARE: if batch size > 1, this will average the styles and make no sense
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
+
+    # Bin loss by style[0]
+    styles = np.array(styles)
+    losses = np.array(losses)
+    bins = np.linspace(styles.min(), styles.max(), 10)
+    digitized = np.digitize(styles, bins)
+    bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]
+
+    return losses.mean(), bin_means, bins
+
+
+
+def train_models(models,
+                 dataloader,
+                 optimizers=None,
+                 num_epochs=10,
+                 device='cuda',
+                 print_timers=False,
+                 save_model_paths=None,
+                 schedulers=None):
+    """
+    Train multiple models on the same dataloader, each with its own optimizer.
+    This is useful since the main bottleneck is I/O: several models can be trained on the same loaded data.
+
+    Parameters:
+    - models: List of models to train.
+    - dataloader: Dictionary with 'train' and 'val' DataLoader objects.
+    - optimizers: List of optimizers for each model (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the models (default is 10).
+    - device: Device to run the models on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_paths: List of paths to save the models after each epoch.
+    - schedulers: List of learning rate schedulers for each model (optional).
+
+    Returns:
+    - train_loss_logs: List of training losses for each model.
+    - val_loss_logs: List of validation losses for each model.
+    """
+
+    if optimizers is None:
+        optimizers = [torch.optim.Adam(model.parameters(), lr=1e-4) for model in models]
+
+    if schedulers is None:
+        schedulers = [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
+                      for optimizer in optimizers]
+
+    models = [model.to(device) for model in models]
+
+    loss_fns = [torch.nn.MSELoss() for _ in models]
+
+    train_loss_logs = [[] for _ in models]
+    val_loss_logs = [[] for _ in models]
+
+    if save_model_paths is None:
+        save_model_paths = [None] * len(models)
+
+    if len(save_model_paths) != len(models) or len(optimizers) != len(models) or len(schedulers) != len(models):
+        raise ValueError("Length of save_model_paths, optimizers, and schedulers must match the number of models.")
+
+    print(f"Starting training for {len(models)} models...")
+
+    for epoch in range(num_epochs):
+
+        for model in models:
+            model.train()
+
+        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
+        io_time = 0.0
+        forward_time = 0.0
+        backward_time = 0.0
+        validation_time = 0.0
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time
+
+        for batch in progress_bar:
+            # I/O timer: time since last batch processed
+            t0 = time.time()
+            io_time += t0 - prev_time
+
+            batch = prepare_data(batch)
+            input = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            # Loop on models for training
+            for i, model in enumerate(models):
+                optimizers[i].zero_grad()
+
+                # Forward pass
+                t1 = time.time()
+                output = model(input, style)
+                loss = loss_fns[i](output, target)
+                forward_time += time.time() - t1
+
+                # Backward pass and optimization
+                t2 = time.time()
+                loss.backward()
+                optimizers[i].step()
+                backward_time += time.time() - t2
+
+                train_loss_logs[i].append(loss.item())
+                progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
+
+            prev_time = time.time()
+
+        # End of epoch, validate the models
+        t3 = time.time()
+        for i, model in enumerate(models):
+            val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fns[i], device)
+            val_loss_logs[i].append(val_loss)
+
+            if save_model_paths[i] is not None:
+                torch.save(model.state_dict(), save_model_paths[i] + f"_epoch_{epoch+1}.pth")
+                torch.save(dict(train_loss_log=train_loss_logs[i],
+                                val_loss_log=val_loss_logs[i],
+                                style_bins_means=style_bins_means,
+                                style_bins=style_bins),
+                           save_model_paths[i] + f"_epoch_{epoch+1}_stats.pth")
+
+            if schedulers[i] is not None:
+                schedulers[i].step(val_loss)
+        validation_time += time.time() - t3
+
+        # Prepare new samples for the next epoch
+        dataloader['train'].dataset.on_epoch_end()
+        dataloader['val'].dataset.on_epoch_end()
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        for i, model in enumerate(models):
+            print(f"Model {i+1} Validation Loss: {val_loss_logs[i][-1]:2.6f}")
+
+        if print_timers:
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"  I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"  Forward time:     {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"  Backward time:    {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"  Validation time:  {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+            print()
+
+    print("Training complete.")
+    return train_loss_logs, val_loss_logs
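
Note (not part of the diff): a minimal sketch of how the new train_model entry point might be wired up. The MomentaDataset constructor arguments, the DataLoader settings, and MomentaModel are placeholders/assumptions for illustration, not names defined by this PR; the only assumptions taken from the diff are that train_model expects a dict of 'train'/'val' DataLoaders whose datasets implement on_epoch_end() and a model called as model(input, style).

    # Illustrative wiring only; replace the placeholders with the repo's real model and dataset arguments.
    from torch.utils.data import DataLoader
    from sCOCA_ML.dataset.momenta_dataset import MomentaDataset
    from sCOCA_ML.train.train_momenta import train_model

    train_set = MomentaDataset(...)   # placeholder constructor arguments (paths, split, etc.)
    val_set = MomentaDataset(...)     # placeholder constructor arguments

    dataloader = {
        'train': DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4),
        'val':   DataLoader(val_set, batch_size=1, shuffle=False),
    }

    model = MomentaModel()            # placeholder: any module taking (input, style)
    train_loss_log, val_loss_log = train_model(model,
                                               dataloader,
                                               num_epochs=20,
                                               device='cuda',
                                               save_model_path='checkpoints/momenta')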