From 1798195db4321b17924b229a960ed6d1ee6d842b Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Wed, 25 Jun 2025 16:59:58 +0200 Subject: [PATCH] changed to train target_crop --- sCOCA_ML/train/train_gravpot.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/sCOCA_ML/train/train_gravpot.py b/sCOCA_ML/train/train_gravpot.py index b235c85..93d9677 100644 --- a/sCOCA_ML/train/train_gravpot.py +++ b/sCOCA_ML/train/train_gravpot.py @@ -13,7 +13,7 @@ def train_model(model, print_timers=False, save_model_path=None, scheduler=None, - shrink:bool = False): + target_crop:int = None): """ Train a model with the given dataloader and optimizer. @@ -26,7 +26,7 @@ def train_model(model, - print_timers: If True, print timing information for each epoch (default is False). - save_model_path: If provided, the model will be saved to this path after each epoch. - scheduler: Learning rate scheduler (optional). - - shrink: If True, the model will output only the center region of size N//2 of the target + - target_crop: If provided, the target will be cropped by this amount from each side (default is None, no cropping). Returns: - train_loss_log: List of training losses for each batch. @@ -67,10 +67,8 @@ def train_model(model, t1 = time.time() output = model(input, style) - if shrink: - # Shrink the target to the center region of size N//2 - N = target.shape[-1] - target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4] + if target_crop: + target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] loss = loss_fn(output, target) forward_time += time.time() - t1 @@ -87,7 +85,7 @@ def train_model(model, # End of epoch, validate the model t3 = time.time() - val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, shrink=shrink) + val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop) val_loss_log.append(val_loss) validation_time += time.time() - t3 @@ -128,7 +126,7 @@ def train_model(model, return train_loss_log, val_loss_log -def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False): +def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None): model.eval() losses = [] styles = [] @@ -141,10 +139,8 @@ def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False): target = batch['target'].to(device) style = batch['style'].to(device) - if shrink: - # Shrink the target to the center region of size N//2 - N = target.shape[-1] - target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4] + if target_crop: + target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop] output = model(input, style) loss = loss_fn(output, target)