changed to train target_crop

This commit is contained in:
Mayeul Aubin 2025-06-25 16:59:58 +02:00
parent 773f9ad860
commit 1798195db4

View file

@ -13,7 +13,7 @@ def train_model(model,
print_timers=False, print_timers=False,
save_model_path=None, save_model_path=None,
scheduler=None, scheduler=None,
shrink:bool = False): target_crop:int = None):
""" """
Train a model with the given dataloader and optimizer. Train a model with the given dataloader and optimizer.
@ -26,7 +26,7 @@ def train_model(model,
- print_timers: If True, print timing information for each epoch (default is False). - print_timers: If True, print timing information for each epoch (default is False).
- save_model_path: If provided, the model will be saved to this path after each epoch. - save_model_path: If provided, the model will be saved to this path after each epoch.
- scheduler: Learning rate scheduler (optional). - scheduler: Learning rate scheduler (optional).
- shrink: If True, the model will output only the center region of size N//2 of the target - target_crop: If provided, the target will be cropped by this amount from each side (default is None, no cropping).
Returns: Returns:
- train_loss_log: List of training losses for each batch. - train_loss_log: List of training losses for each batch.
@ -67,10 +67,8 @@ def train_model(model,
t1 = time.time() t1 = time.time()
output = model(input, style) output = model(input, style)
if shrink: if target_crop:
# Shrink the target to the center region of size N//2 target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
N = target.shape[-1]
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
loss = loss_fn(output, target) loss = loss_fn(output, target)
forward_time += time.time() - t1 forward_time += time.time() - t1
@ -87,7 +85,7 @@ def train_model(model,
# End of epoch, validate the model # End of epoch, validate the model
t3 = time.time() t3 = time.time()
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, shrink=shrink) val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop)
val_loss_log.append(val_loss) val_loss_log.append(val_loss)
validation_time += time.time() - t3 validation_time += time.time() - t3
@ -128,7 +126,7 @@ def train_model(model,
return train_loss_log, val_loss_log return train_loss_log, val_loss_log
def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False): def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
model.eval() model.eval()
losses = [] losses = []
styles = [] styles = []
@ -141,10 +139,8 @@ def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False):
target = batch['target'].to(device) target = batch['target'].to(device)
style = batch['style'].to(device) style = batch['style'].to(device)
if shrink: if target_crop:
# Shrink the target to the center region of size N//2 target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
N = target.shape[-1]
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
output = model(input, style) output = model(input, style)
loss = loss_fn(output, target) loss = loss_fn(output, target)