From c07ec8f8cf4378a04e0da59e510b9dab19fdaa97 Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Fri, 6 Jun 2025 13:52:17 +0200 Subject: [PATCH] expansion --- sCOCA_ML/dataset/gravpot_dataset.py | 18 ++++ sCOCA_ML/models/UNet_models.py | 2 +- sCOCA_ML/prepare_data/__init__.py | 0 sCOCA_ML/prepare_data/prepare_gravpot_data.py | 17 +++ sCOCA_ML/train/__init__.py | 0 sCOCA_ML/train/train_gravpot.py | 102 ++++++++++++++++++ 6 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 sCOCA_ML/prepare_data/__init__.py create mode 100644 sCOCA_ML/prepare_data/prepare_gravpot_data.py create mode 100644 sCOCA_ML/train/__init__.py create mode 100644 sCOCA_ML/train/train_gravpot.py diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py index d369801..04bc706 100644 --- a/sCOCA_ML/dataset/gravpot_dataset.py +++ b/sCOCA_ML/dataset/gravpot_dataset.py @@ -206,3 +206,21 @@ class GravPotDataset(Dataset): def on_epoch_end(self): """Call this at the end of each epoch to regenerate offset + time choices.""" self._prepare_samples() + + +class SubDataset(Dataset): + def __init__(self, dataset: GravPotDataset, indices: list): + self.dataset = dataset + self.indices = indices + + def __len__(self): + return len(self.indices) + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def on_epoch_end(self): + self.dataset.on_epoch_end() + + + diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py index 56ee21e..02c3ec4 100644 --- a/sCOCA_ML/models/UNet_models.py +++ b/sCOCA_ML/models/UNet_models.py @@ -49,7 +49,7 @@ class UNet3D(BaseModel): The FiLM layers are used to condition the feature maps on style parameters. """ - super().init(N=N, + super().__init__(N=N, in_channels=in_channels, out_channels=out_channels, style_parameters=style_dim, diff --git a/sCOCA_ML/prepare_data/__init__.py b/sCOCA_ML/prepare_data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sCOCA_ML/prepare_data/prepare_gravpot_data.py b/sCOCA_ML/prepare_data/prepare_gravpot_data.py new file mode 100644 index 0000000..e4b7190 --- /dev/null +++ b/sCOCA_ML/prepare_data/prepare_gravpot_data.py @@ -0,0 +1,17 @@ +def prepare_data(batch): + + phi_ini = batch['input'][:, [1]] + D1 = batch['style'][:, [0]] + D2 = batch['style'][:, [1]] + gravpot = batch['target'][:, [0]] + + _input = batch['input'] + _target = (gravpot/D1 - phi_ini)/D1 + _style = batch['style'] + + return { + 'input': _input, + 'target': _target, + 'style': _style + } + \ No newline at end of file diff --git a/sCOCA_ML/train/__init__.py b/sCOCA_ML/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py new file mode 100644 index 0000000..722b21c --- /dev/null +++ b/sCOCA_ML/train/train_gravpot.py @@ -0,0 +1,102 @@ +import numpy as np +import matplotlib.pyplot as plt +from tqdm import tqdm +import torch +import time +from ..prepare_data.prepare_gravpot_data import prepare_data + +def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False): + if optimizer is None: + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + model.to(device) + loss_fn = torch.nn.MSELoss() + loss_log = [] + + for epoch in range(num_epochs): + model.train() + progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch') + io_time = 0.0 + forward_time = 0.0 + backward_time = 0.0 + validation_time = 0.0 + + prev_time = time.time() + for batch in 
progress_bar:
+            # I/O timer: time since last batch processed
+            t0 = time.time()
+            io_time += t0 - prev_time
+
+            batch = prepare_data(batch)
+            inputs = batch['input'].to(device)  # 'inputs' avoids shadowing the builtin input()
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            optimizer.zero_grad()
+
+            # Forward pass
+            t1 = time.time()
+            output = model(inputs, style)
+            loss = loss_fn(output, target)
+            forward_time += time.time() - t1
+
+            # Backward pass and optimization
+            t2 = time.time()
+            loss.backward()
+            optimizer.step()
+            backward_time += time.time() - t2
+
+            loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
+            progress_bar.set_postfix(loss=loss.item())
+
+            prev_time = time.time()  # End of loop, for next I/O timing
+
+        # End of epoch, validate the model
+        t3 = time.time()
+        val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
+        validation_time += time.time() - t3
+
+        print(f"Validation Loss: {val_loss:.4f}")
+        bin_width = max([len(f"{v:.2f}") for v in list(style_bins[:-1]) + list(style_bins_means) + [2]])  # [2] keeps max() non-empty and sets a minimum width
+        bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
+        means_str = "Means:      " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
+        print(bins_str)
+        print(means_str)
+
+        if print_timers:
+            total_time = io_time + forward_time + backward_time + validation_time
+            print(f"Epoch {epoch+1} Timings:")
+            print(f"  I/O time:        {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
+            print(f"  Forward time:    {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
+            print(f"  Backward time:   {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
+            print(f"  Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
+
+    return loss_log
+
+
+def validate(model, val_loader, loss_fn, device='cuda'):
+    model.eval()
+    losses = []
+    styles = []
+    progress_bar = tqdm(val_loader, desc="Validation", unit='batch')
+
+    with torch.no_grad():
+        for batch in progress_bar:
+            batch = prepare_data(batch)
+            inputs = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            output = model(inputs, style)
+            loss = loss_fn(output, target)
+            losses.append(loss.item())
+            styles.append(float(style[:, 0].mean()))  # one value per batch, to match one loss per batch
+            progress_bar.set_postfix(loss=loss.item())
+
+    # Bin per-batch loss by the batch-mean of style[0]
+    styles = np.array(styles)
+    losses = np.array(losses)
+    bins = np.linspace(styles.min(), styles.max(), 10)
+    digitized = np.minimum(np.digitize(styles, bins), len(bins) - 1)  # clamp so styles.max() lands in the last bin
+    bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else np.nan for i in range(1, len(bins))]  # nan marks empty bins
+
+    return losses.mean(), bin_means, bins
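
A note on the target transform in prepare_gravpot_data.py: the potential is
rescaled by the first style parameter twice, target = (gravpot/D1 - phi_ini)/D1.
Solving that algebraically gives the raw potential back from a model
prediction. A minimal sketch (the name recover_gravpot is ours, not part of
the patch; phi_ini and D1 must be the same tensors, input[:, [1]] and
style[:, [0]], that prepare_data used on the way in):

    # Inverse of the prepare_data target transform:
    #   target = (gravpot / D1 - phi_ini) / D1
    # solved for gravpot.
    def recover_gravpot(pred, phi_ini, D1):
        return (pred * D1 + phi_ini) * D1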
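
For reviewers, here is how the new pieces are expected to fit together,
as a hedged sketch: the GravPotDataset and UNet3D constructor arguments are
not shown in this patch, so the ones below are placeholders, not the real
signatures; only SubDataset, prepare_data and train_model are used exactly
as defined here. It also assumes each dataset item is a dict of tensors with
'input', 'target' and 'style' keys, so the default collate produces the
batches prepare_data expects.

    import torch
    from torch.utils.data import DataLoader
    from sCOCA_ML.dataset.gravpot_dataset import GravPotDataset, SubDataset
    from sCOCA_ML.models.UNet_models import UNet3D
    from sCOCA_ML.train.train_gravpot import train_model

    dataset = GravPotDataset(data_dir='/path/to/sims')  # hypothetical argument
    split = int(0.9 * len(dataset))
    train_set = SubDataset(dataset, list(range(split)))
    val_set = SubDataset(dataset, list(range(split, len(dataset))))

    dataloader = {
        'train': DataLoader(train_set, batch_size=8, shuffle=True),
        'val': DataLoader(val_set, batch_size=8),
    }

    model = UNet3D()  # hypothetical; the real constructor takes N, channels, style_dim, ...
    loss_log = train_model(model, dataloader, num_epochs=10,
                           device='cuda', print_timers=True)

One thing the sketch exposes: train_model never calls on_epoch_end, even
though the GravPotDataset docstring says to call it at the end of each epoch
to regenerate the offset and time choices. That may be intentional, but as
written the caller has to do it (e.g. train_set.on_epoch_end() after each
epoch), which is easy to miss.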
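
Finally, train_model returns loss_log as a list of (per-batch style values,
batch loss) tuples, so a quick convergence plot only needs the second element
of each pair. A small sketch:

    import matplotlib.pyplot as plt

    losses = [loss for _, loss in loss_log]
    plt.semilogy(losses)
    plt.xlabel('training batch')
    plt.ylabel('MSE loss')
    plt.show()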