From 6c526d71152b8fafcb59cd402f67ea84e318175d Mon Sep 17 00:00:00 2001
From: Mayeul Aubin
Date: Tue, 17 Jun 2025 18:07:06 +0200
Subject: [PATCH] many improvements

---
 sCOCA_ML/dataset/gravpot_dataset.py           | 84 +++++++++++++++---
 sCOCA_ML/models/UNet_models.py                | 81 +++++++++++++----
 sCOCA_ML/prepare_data/prepare_gravpot_data.py | 21 +++--
 sCOCA_ML/train/train_gravpot.py               | 86 +++++++++++++++----
 4 files changed, 219 insertions(+), 53 deletions(-)

diff --git a/sCOCA_ML/dataset/gravpot_dataset.py b/sCOCA_ML/dataset/gravpot_dataset.py
index 04bc706..0329d15 100644
--- a/sCOCA_ML/dataset/gravpot_dataset.py
+++ b/sCOCA_ML/dataset/gravpot_dataset.py
@@ -36,7 +36,7 @@ class GravPotDataset(Dataset):
                  N:int=128,
                  N_full:int=768,
                  match_str:str='train',
-                 device=torch.device('cpu'),
+                 device='cpu',
                  initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
                  target_variable:str='gravpot',
                  style_files:str='cosmo_and_time_parameters',
@@ -50,6 +50,7 @@ class GravPotDataset(Dataset):
         - N: Size of the chunks to read (N x N x N).
         - N_full: Full size of the simulation box (N_full x N_full x N_full).
         - device: Device to load tensors onto (default is CPU)."""
+        super().__init__()
 
         self.initial_conditions_variables = initial_conditions_variables
         self.target_variable = target_variable
@@ -152,10 +153,44 @@ class GravPotDataset(Dataset):
     def __len__(self):
         return len(self.samples)
 
+    def files_from_samples(self, sample):
+        """
+        Return the paths to the files for a given sample.
+        """
+        ID, t, ox, oy, oz = sample
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
+    def files_from_ID_and_time(self, ID, t):
+        """
+        Return the paths to the files for a given ID and time.
+        """
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
     def __getitem__(self, idx):
         from pysbmy.field import read_field_chunk_3D_periodic
         from io import BytesIO
+        import torch
         from sbmy_control.low_level import stdout_redirector, stderr_redirector
 
         f = BytesIO()
@@ -170,12 +205,11 @@ class GravPotDataset(Dataset):
         style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
 
         # Read 3D chunks
-        with stdout_redirector(f):
-            input_arrays = [
-                read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
-                for file, varname in zip(input_paths, self.initial_conditions_variables)
-            ]
-            target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
+        input_arrays = [
+            read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
+            for file, varname in zip(input_paths, self.initial_conditions_variables)
+        ]
+        target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
 
         # Stack the input arrays
         input_tensor = np.stack(input_arrays, axis=0)
@@ -206,21 +240,45 @@ class GravPotDataset(Dataset):
     def on_epoch_end(self):
         """Call this at the end of each epoch to regenerate offset + time choices."""
         self._prepare_samples()
-
 class SubDataset(Dataset):
-    def __init__(self, dataset: GravPotDataset, indices: list):
-        self.dataset = dataset
-        self.indices = indices
+    def __init__(self, dataset: GravPotDataset, ID_list: list):
+        from copy import deepcopy
+        self.dataset = deepcopy(dataset)
+        self.ids = ID_list
+        self.dataset.ids = ID_list
 
     def __len__(self):
-        return len(self.indices)
+        return len(self.dataset)
 
     def __getitem__(self, idx):
-        return self.dataset[self.indices[idx]]
+        return self.dataset[idx]
 
     def on_epoch_end(self):
+        self.dataset.ids = self.ids
         self.dataset.on_epoch_end()
+
+
+def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: int = 42):
+    """
+    Splits the dataset into training and validation sets.
+
+    Parameters:
+    - dataset: The GravPotDataset to split.
+    - val_fraction: Fraction of the dataset to use for validation.
+
+    Returns:
+    - train_dataset: SubDataset for training.
+    - val_dataset: SubDataset for validation.
+    """
+    from sklearn.model_selection import train_test_split
+    train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
+    train_dataset = SubDataset(dataset, train_ids)
+    val_dataset = SubDataset(dataset, val_ids)
+    train_dataset.dataset._prepare_samples()
+    val_dataset.dataset._prepare_samples()
+
+    return train_dataset, val_dataset
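
A minimal usage sketch for the new ID-based split (not part of the patch; the
root_dir argument and its value are assumptions, only the parameters visible in
the hunks above are taken from the code):

    from sCOCA_ML.dataset.gravpot_dataset import GravPotDataset, train_val_split

    # Build the full dataset; after this patch, 'device' is a plain string.
    dataset = GravPotDataset(root_dir='/data/gravpot_sims',   # hypothetical path
                             N=128, N_full=768, device='cpu')

    # Split by simulation ID, so no run contributes chunks to both sets.
    # Each SubDataset deep-copies the parent, restricts it to its own IDs,
    # and re-draws chunk offsets and times via _prepare_samples().
    train_set, val_set = train_val_split(dataset, val_fraction=0.2, seed=42)
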
diff --git a/sCOCA_ML/models/UNet_models.py b/sCOCA_ML/models/UNet_models.py
index 02c3ec4..57dc489 100644
--- a/sCOCA_ML/models/UNet_models.py
+++ b/sCOCA_ML/models/UNet_models.py
@@ -30,6 +30,27 @@ class UNetBlock(nn.Module):
         x = self.film(x, style)
         return x
 
+class UNetEncLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, style_dim=None):
+        super(UNetEncLayer, self).__init__()
+        self.block = UNetBlock(in_channels, out_channels, style_dim)
+        self.pool = nn.MaxPool3d(2)
+
+    def forward(self, x, style=None):
+        x = self.block(x, style)
+        return x, self.pool(x)
+
+class UNetDecLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, skip_connection_channels, style_dim=None):
+        super(UNetDecLayer, self).__init__()
+        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
+        self.block = UNetBlock(out_channels + skip_connection_channels, out_channels, style_dim)
+
+    def forward(self, x, skip_connection, style=None):
+        x = self.up(x)
+        x = torch.cat([x, skip_connection], dim=1)
+        return self.block(x, style)
+
 class UNet3D(BaseModel):
     def __init__(self, N: int = 128,
                  in_channels: int = 2,
@@ -54,23 +75,51 @@ class UNet3D(BaseModel):
                          out_channels=out_channels,
                          style_parameters=style_dim,
                          device=device)
+        import numpy as np
 
-        self.enc1 = UNetBlock(in_channels, 32, style_dim)
-        self.pool1 = nn.MaxPool3d(2)
-        self.enc2 = UNetBlock(32, 64, style_dim)
-        self.pool2 = nn.MaxPool3d(2)
-        self.bottleneck = UNetBlock(64, 128, style_dim)
+        self.depth = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
+        self.first_layer_channel_exponent = 3
 
-        self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
-        self.dec2 = UNetBlock(128, 64)
-        self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
-        self.dec1 = UNetBlock(64, 32)
-        self.final = nn.Conv3d(32, out_channels, kernel_size=1)
+        self.enc = []
+
+        for i in range(self.depth):
+            in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
+
+        self.enc = nn.ModuleList(self.enc)
+
+        self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth - 1),
+                                    2**(self.first_layer_channel_exponent + self.depth), style_dim)
+
+        self.dec = []
+
+        for i in range(self.depth - 1, -1, -1):
+            in_ch = 2**(self.first_layer_channel_exponent + i + 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            skip_conn_ch = out_ch
+            self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
+
+        self.dec = nn.ModuleList(self.dec)
+
+
+        self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)
+
 
     def forward(self, x, style):
-        e1 = self.enc1(x, style)
-        e2 = self.enc2(self.pool1(e1), style)
-        b = self.bottleneck(self.pool2(e2), style)
-        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
-        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
-        return self.final(d1)
+
+        out = x
+        outlist = []
+
+        for i in range(self.depth):
+            skip, out = self.enc[i](out, style)
+            outlist.append(skip)
+
+        out = self.bottleneck(out, style)
+
+        for i in range(self.depth):
+            out = self.dec[i](out, outlist[self.depth - 1 - i], style)
+
+        return self.final(out)
+
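
The constructor above now derives the network depth and layer widths from the
input size N instead of hard-coding two levels. A quick arithmetic check of
that rule (illustrative only, not part of the patch):

    import numpy as np

    N = 128
    depth = int(np.floor(np.log2(N))) - 1            # 6 resolution levels for N = 128
    enc_widths = [2**(3 + i) for i in range(depth)]  # [8, 16, 32, 64, 128, 256]; 3 is first_layer_channel_exponent
    bottleneck_width = 2**(3 + depth)                # 512
    # Each encoder level halves the grid with MaxPool3d(2),
    # so the chunk size must be divisible by 2**depth.
    assert N % 2**depth == 0
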
diff --git a/sCOCA_ML/prepare_data/prepare_gravpot_data.py b/sCOCA_ML/prepare_data/prepare_gravpot_data.py
index e4b7190..587bec0 100644
--- a/sCOCA_ML/prepare_data/prepare_gravpot_data.py
+++ b/sCOCA_ML/prepare_data/prepare_gravpot_data.py
@@ -1,12 +1,23 @@
-def prepare_data(batch):
+def prepare_data(batch,
+                 scale_phi_ini:float = 1000.0,
+                 scale_delta_ini:float = 12.0,
+                 scale_target:float = 600.0,
+                 ):
 
-    phi_ini = batch['input'][:, [1]]
-    D1 = batch['style'][:, [0]]
-    D2 = batch['style'][:, [1]]
-    gravpot = batch['target'][:, [0]]
+    # delta_ini = batch['input'][:, [0], :, :, :]
+    phi_ini = batch['input'][:, [1], :, :, :]
+    D1 = batch['style'][:, [0], None, None, None]
+    # D2 = batch['style'][:, [1], None, None, None]
+    gravpot = batch['target'][:, [0], :, :, :]
 
+    _input = batch['input']
+    _input[:, 0, :, :, :] /= scale_delta_ini
+    _input[:, 1, :, :, :] /= scale_phi_ini
+
     _target = (gravpot/D1 - phi_ini)/D1
+    _target /= scale_target
+
+    _style = batch['style']
 
     return {
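
prepare_data now rescales both input channels and trains on the normalised
residual target = ((gravpot/D1) - phi_ini)/D1 / scale_target rather than on
the raw potential. A round-trip sketch of the inverse map, useful when turning
network predictions back into a potential (the unprepare_target helper is
hypothetical, not part of the patch):

    import torch

    def unprepare_target(pred, phi_ini, D1, scale_target=600.0):
        # Invert prepare_data's target transform:
        # gravpot = (pred * scale_target * D1 + phi_ini) * D1
        return (pred * scale_target * D1 + phi_ini) * D1

    # Round trip on random data with broadcastable shapes.
    gravpot = torch.randn(2, 1, 8, 8, 8)
    phi_ini = torch.randn(2, 1, 8, 8, 8)
    D1 = torch.rand(2, 1, 1, 1, 1) + 0.5
    target = (gravpot / D1 - phi_ini) / D1 / 600.0
    assert torch.allclose(unprepare_target(target, phi_ini, D1), gravpot, atol=1e-4)
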
diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py
index 722b21c..da6a024 100644
--- a/sCOCA_ML/train/train_gravpot.py
+++ b/sCOCA_ML/train/train_gravpot.py
@@ -5,12 +5,39 @@ import torch
 import time
 from ..prepare_data.prepare_gravpot_data import prepare_data
 
-def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
+def train_model(model,
+                dataloader,
+                optimizer=None,
+                num_epochs=10,
+                device='cuda',
+                print_timers=False,
+                save_model_path=None,
+                scheduler=None):
+    """
+    Train a model with the given dataloader and optimizer.
+
+    Parameters:
+    - model: The model to train.
+    - dataloader: A dictionary with 'train' and 'val' DataLoader objects.
+    - optimizer: The optimizer to use for training (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the model (default is 10).
+    - device: Device to run the model on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_path: If provided, the model will be saved to this path after each epoch.
+    - scheduler: Learning rate scheduler (default is ReduceLROnPlateau on the validation loss).
+
+    Returns:
+    - train_loss_log: List of training losses for each batch.
+    - val_loss_log: List of validation losses for each epoch."""
+
     if optimizer is None:
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    if scheduler is None:
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//4)
     model.to(device)
     loss_fn = torch.nn.MSELoss()
-    loss_log = []
+    train_loss_log = []
+    val_loss_log = []
 
     for epoch in range(num_epochs):
         model.train()
@@ -20,7 +47,8 @@ def train_model(model,
         backward_time = 0.0
         validation_time = 0.0
 
-        prev_time = time.time()
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time  # For I/O timing
         for batch in progress_bar:
             # I/O timer: time since last batch processed
             t0 = time.time()
@@ -45,32 +73,52 @@ def train_model(model,
             optimizer.step()
             backward_time += time.time() - t2
 
-            loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
-            progress_bar.set_postfix(loss=loss.item())
+            train_loss_log.append(loss.item())
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
 
             prev_time = time.time()  # End of loop, for next I/O timing
 
         # End of epoch, validate the model
        t3 = time.time()
         val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
+        val_loss_log.append(val_loss)
         validation_time += time.time() - t3
 
-        print(f"Validation Loss: {val_loss:.4f}")
+        # Prepare new samples for the next epoch
+        dataloader['train'].dataset.on_epoch_end()
+        dataloader['val'].dataset.on_epoch_end()
+
+        if save_model_path is not None:
+            torch.save(model.state_dict(), save_model_path + f"_epoch_{epoch+1}.pth")
+            torch.save(dict(train_loss_log=train_loss_log,
+                            val_loss_log=val_loss_log,
+                            style_bins_means=style_bins_means,
+                            style_bins=style_bins),
+                       save_model_path + f"_epoch_{epoch+1}_stats.pth")
+
+        if scheduler is not None:
+            scheduler.step(val_loss)
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        print(f"Validation Loss: {val_loss:2.6f}")
         bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]])  # +[2] to avoid empty
-        bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
-        means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
+        bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]])
+        means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means])
         print(bins_str)
         print(means_str)
+        print()
 
         if print_timers:
-            total_time = io_time + forward_time + backward_time + validation_time
-            print(f"Epoch {epoch+1} Timings:")
-            print(f"  I/O time: {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
-            print(f"  Forward time: {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
-            print(f"  Backward time: {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
-            print(f"  Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"  I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"  Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"  Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"  Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+            print()
 
-    return loss_log
+    return train_loss_log, val_loss_log
 
 
 def validate(model, val_loader, loss_fn, device='cuda'):
@@ -89,11 +137,11 @@
             output = model(input, style)
             loss = loss_fn(output, target)
             losses.append(loss.item())
-            styles.append(style[:, 0].cpu().numpy())
-            progress_bar.set_postfix(loss=loss.item())
+            styles.append(style[:, 0].cpu().numpy().mean())  # BEWARE: if batch size > 1, this will average the styles and make no sense
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
 
     # Bin loss by style[0]
-    styles = np.concatenate(styles)
+    styles = np.array(styles)
     losses = np.array(losses)
     bins = np.linspace(styles.min(), styles.max(), 10)
     digitized = np.digitize(styles, bins)
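
An end-to-end sketch of how the pieces introduced in this patch fit together
(not part of the patch; paths, batch size and the exact UNet3D arguments are
assumptions):

    import torch
    from torch.utils.data import DataLoader
    from sCOCA_ML.dataset.gravpot_dataset import GravPotDataset, train_val_split
    from sCOCA_ML.models.UNet_models import UNet3D
    from sCOCA_ML.train.train_gravpot import train_model

    dataset = GravPotDataset(root_dir='/data/gravpot_sims', N=128)  # hypothetical path
    train_set, val_set = train_val_split(dataset, val_fraction=0.2, seed=42)

    # train_model expects a dict of DataLoaders and calls
    # dataloader[...].dataset.on_epoch_end() itself after every epoch.
    dataloader = {
        'train': DataLoader(train_set, batch_size=1, shuffle=True),
        # batch_size=1 for 'val' because validate() averages style[:, 0] per batch.
        'val': DataLoader(val_set, batch_size=1, shuffle=False),
    }

    model = UNet3D(N=128, in_channels=2, out_channels=1)
    train_loss, val_loss = train_model(model, dataloader,
                                       num_epochs=10,
                                       device='cuda' if torch.cuda.is_available() else 'cpu',
                                       print_timers=True,
                                       save_model_path='checkpoints/gravpot_unet')  # hypothetical prefix
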