From 773f9ad8604459e76705353dc469c6ecbc31873c Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Wed, 25 Jun 2025 15:49:10 +0200 Subject: [PATCH] train shrink --- sCOCA_ML/train/train_gravpot.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py index b0c1141..b235c85 100644 --- a/sCOCA_ML/train/train_gravpot.py +++ b/sCOCA_ML/train/train_gravpot.py @@ -12,7 +12,8 @@ def train_model(model, device='cuda', print_timers=False, save_model_path=None, - scheduler=None): + scheduler=None, + shrink:bool = False): """ Train a model with the given dataloader and optimizer. @@ -25,6 +26,7 @@ def train_model(model, - print_timers: If True, print timing information for each epoch (default is False). - save_model_path: If provided, the model will be saved to this path after each epoch. - scheduler: Learning rate scheduler (optional). + - shrink: If True, the model will output only the center region of size N//2 of the target Returns: - train_loss_log: List of training losses for each batch. @@ -64,6 +66,11 @@ def train_model(model, # Forward pass t1 = time.time() output = model(input, style) + + if shrink: + # Shrink the target to the center region of size N//2 + N = target.shape[-1] + target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4] loss = loss_fn(output, target) forward_time += time.time() - t1 @@ -80,7 +87,7 @@ def train_model(model, # End of epoch, validate the model t3 = time.time() - val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device) + val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, shrink=shrink) val_loss_log.append(val_loss) validation_time += time.time() - t3 @@ -112,7 +119,7 @@ def train_model(model, if print_timers: total_time = time.time() - epoch_start_time print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s") - print(f" I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%") + print(f" Sync time + I/O: {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%") print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%") print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%") print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%") @@ -121,7 +128,7 @@ def train_model(model, return train_loss_log, val_loss_log -def validate(model, val_loader, loss_fn, device='cuda'): +def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False): model.eval() losses = [] styles = [] @@ -133,6 +140,11 @@ def validate(model, val_loader, loss_fn, device='cuda'): input = batch['input'].to(device) target = batch['target'].to(device) style = batch['style'].to(device) + + if shrink: + # Shrink the target to the center region of size N//2 + N = target.shape[-1] + target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4] output = model(input, style) loss = loss_fn(output, target)