diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py
index 0329d15..7315f57 100644
--- a/sCOCA_ML/dataset/gravpot_dataset.py
+++ b/sCOCA_ML/dataset/gravpot_dataset.py
@@ -186,16 +186,18 @@ class GravPotDataset(Dataset):
             'style': style_path
         }
 
-
-    def __getitem__(self, idx):
+    def get_data(self, ID, t, ox, oy, oz):
+        """
+        Get the data for a specific ID, time, and offsets.
+        Returns a dictionary with input, target, and style tensors.
+        """
+
         from pysbmy.field import read_field_chunk_3D_periodic
         from io import BytesIO
         import torch
         from sbmy_control.low_level import stdout_redirector, stderr_redirector
         f = BytesIO()
 
-        ID, t, ox, oy, oz = self.samples[idx]
-
         # Filepaths
         input_paths = [
             os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
@@ -236,6 +238,15 @@ class GravPotDataset(Dataset):
             'time': t,
             'offset': (ox, oy, oz)
         }
 
+
+
+
+    def __getitem__(self, idx):
+
+        ID, t, ox, oy, oz = self.samples[idx]
+        return self.get_data(ID, t, ox, oy, oz)
+
+
     def on_epoch_end(self):
         """Call this at the end of each epoch to regenerate offset + time choices."""
diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py
index 57dc489..fcd8eef 100644
--- a/sCOCA_ML/models/UNet_models.py
+++ b/sCOCA_ML/models/UNet_models.py
@@ -56,7 +56,9 @@ class UNet3D(BaseModel):
                  in_channels: int = 2,
                  out_channels: int = 1,
                  style_dim: int = 2,
-                 device: torch.device = torch.device('cpu')):
+                 device: torch.device = torch.device('cpu'),
+                 first_layer_channel_exponent: int = 3,
+                 ):
         """
         3D U-Net model with optional FiLM layers for style conditioning.
         Parameters:
@@ -78,7 +80,7 @@ class UNet3D(BaseModel):
         import numpy as np
 
         self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
 
-        self.first_layer_channel_exponent = 3
+        self.first_layer_channel_exponent = first_layer_channel_exponent
 
         self.enc=[]
diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py
index da6a024..b0c1141 100644
--- a/sCOCA_ML/train/train_gravpot.py
+++ b/sCOCA_ML/train/train_gravpot.py
@@ -33,7 +33,7 @@ def train_model(model,
     if optimizer is None:
         optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
     if scheduler is None:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//4)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
     model.to(device)
     loss_fn = torch.nn.MSELoss()
     train_loss_log = []
@@ -148,3 +148,136 @@ def validate(model, val_loader, loss_fn, device='cuda'):
     bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]
 
     return losses.mean(), bin_means, bins
+
+
+
+def train_models(models,
+                 dataloader,
+                 optimizers=None,
+                 num_epochs=10,
+                 device='cuda',
+                 print_timers=False,
+                 save_model_paths=None,
+                 schedulers=None):
+    """
+    Train multiple models on a shared dataloader, each with its own optimizer.
+    This is useful when the main bottleneck is I/O: every model is trained on the same loaded batch.
+
+    Parameters:
+    - models: List of models to train.
+    - dataloader: Dictionary with 'train' and 'val' DataLoader objects.
+    - optimizers: List of optimizers, one per model (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the models (default is 10).
+    - device: Device to run the models on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_paths: List of path prefixes used to save each model after every epoch.
+    - schedulers: List of learning rate schedulers, one per model (optional).
+
+    Returns:
+    - train_loss_logs: List of training losses for each model.
+    - val_loss_logs: List of validation losses for each model."""
+
+    if optimizers is None:
+        optimizers = [torch.optim.Adam(model.parameters(), lr=1e-4) for model in models]
+
+    if schedulers is None:
+        schedulers = [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
+                      for optimizer in optimizers]
+
+    models = [model.to(device) for model in models]
+
+    loss_fns = [torch.nn.MSELoss() for _ in models]
+
+    train_loss_logs = [[] for _ in models]
+    val_loss_logs = [[] for _ in models]
+
+    if save_model_paths is None:
+        save_model_paths = [None] * len(models)
+
+    if len(save_model_paths) != len(models) or len(optimizers) != len(models) or len(schedulers) != len(models):
+        raise ValueError("Length of save_model_paths, optimizers, and schedulers must match the number of models.")
+
+    print(f"Starting training for {len(models)} models...")
+
+    for epoch in range(num_epochs):
+
+        for model in models:
+            model.train()
+
+        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
+        io_time = 0.0
+        forward_time = 0.0
+        backward_time = 0.0
+        validation_time = 0.0
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time
+
+        for batch in progress_bar:
+            # I/O timer: time since last batch processed
+            t0 = time.time()
+            io_time += t0 - prev_time
+
+            batch = prepare_data(batch)
+            input = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            # Loop on models for training
+            for i, model in enumerate(models):
+                optimizers[i].zero_grad()
+
+                # Forward pass
+                t1 = time.time()
+                output = model(input, style)
+                loss = loss_fns[i](output, target)
+                forward_time += time.time() - t1
+
+                # Backward pass and optimization
+                t2 = time.time()
+                loss.backward()
+                optimizers[i].step()
+                backward_time += time.time() - t2
+
+                train_loss_logs[i].append(loss.item())
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
+
+            prev_time = time.time()
+
+        # End of epoch, validate the models
+        t3 = time.time()
+        for i, model in enumerate(models):
+            val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fns[i], device)
+            val_loss_logs[i].append(val_loss)
+
+            if save_model_paths[i] is not None:
+                torch.save(model.state_dict(), save_model_paths[i] + f"_epoch_{epoch+1}.pth")
+                torch.save(dict(train_loss_log=train_loss_logs[i],
+                                val_loss_log=val_loss_logs[i],
+                                style_bins_means=style_bins_means,
+                                style_bins=style_bins),
+                           save_model_paths[i] + f"_epoch_{epoch+1}_stats.pth")
+
+            if schedulers[i] is not None:
+                schedulers[i].step(val_loss)
+        validation_time += time.time() - t3
+
+        # Prepare new samples for the next epoch
+        dataloader['train'].dataset.on_epoch_end()
+        dataloader['val'].dataset.on_epoch_end()
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        for i, model in enumerate(models):
+            print(f"Model {i+1} Validation Loss: {val_loss_logs[i][-1]:2.6f}")
+
+        if print_timers:
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"   I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"   Forward time:     {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"   Backward time:    {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"   Validation time:  {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+        print()
+
+    print("Training complete.")
+    return train_loss_logs, val_loss_logs
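
Usage sketch for the new train_models entry point (illustrative only: the GravPotDataset constructor arguments, the N and batch-size values, and the checkpoint prefixes below are assumptions, not defined by this patch):

    import torch
    from torch.utils.data import DataLoader
    from sCOCA_ML.dataset.gravpot_dataset import GravPotDataset
    from sCOCA_ML.models.UNet_models import UNet3D
    from sCOCA_ML.train.train_gravpot import train_models

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical dataset construction; adapt to the real GravPotDataset signature.
    dataloader = {
        'train': DataLoader(GravPotDataset(root_dir='/path/to/train'), batch_size=4, shuffle=True),
        'val':   DataLoader(GravPotDataset(root_dir='/path/to/val'), batch_size=4),
    }

    # Two U-Nets of different widths share every loaded batch, amortising the I/O cost;
    # first_layer_channel_exponent is the new constructor knob, N=64 is an assumed value.
    models = [
        UNet3D(N=64, in_channels=2, out_channels=1, style_dim=2, device=device,
               first_layer_channel_exponent=3),
        UNet3D(N=64, in_channels=2, out_channels=1, style_dim=2, device=device,
               first_layer_channel_exponent=4),
    ]

    train_loss_logs, val_loss_logs = train_models(
        models,
        dataloader,
        num_epochs=10,
        device=device,
        print_timers=True,
        save_model_paths=['runs/unet_e3', 'runs/unet_e4'],  # per-epoch checkpoint prefixes
    )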