changed to train target_crop
This commit is contained in:
parent
773f9ad860
commit
1798195db4
1 changed files with 8 additions and 12 deletions
|
@ -13,7 +13,7 @@ def train_model(model,
|
||||||
print_timers=False,
|
print_timers=False,
|
||||||
save_model_path=None,
|
save_model_path=None,
|
||||||
scheduler=None,
|
scheduler=None,
|
||||||
shrink:bool = False):
|
target_crop:int = None):
|
||||||
"""
|
"""
|
||||||
Train a model with the given dataloader and optimizer.
|
Train a model with the given dataloader and optimizer.
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ def train_model(model,
|
||||||
- print_timers: If True, print timing information for each epoch (default is False).
|
- print_timers: If True, print timing information for each epoch (default is False).
|
||||||
- save_model_path: If provided, the model will be saved to this path after each epoch.
|
- save_model_path: If provided, the model will be saved to this path after each epoch.
|
||||||
- scheduler: Learning rate scheduler (optional).
|
- scheduler: Learning rate scheduler (optional).
|
||||||
- shrink: If True, the model will output only the center region of size N//2 of the target
|
- target_crop: If provided, the target will be cropped by this amount from each side (default is None, no cropping).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- train_loss_log: List of training losses for each batch.
|
- train_loss_log: List of training losses for each batch.
|
||||||
|
@ -67,10 +67,8 @@ def train_model(model,
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
|
|
||||||
if shrink:
|
if target_crop:
|
||||||
# Shrink the target to the center region of size N//2
|
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||||
N = target.shape[-1]
|
|
||||||
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
|
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target)
|
||||||
forward_time += time.time() - t1
|
forward_time += time.time() - t1
|
||||||
|
|
||||||
|
@ -87,7 +85,7 @@ def train_model(model,
|
||||||
|
|
||||||
# End of epoch, validate the model
|
# End of epoch, validate the model
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, shrink=shrink)
|
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, target_crop=target_crop)
|
||||||
val_loss_log.append(val_loss)
|
val_loss_log.append(val_loss)
|
||||||
validation_time += time.time() - t3
|
validation_time += time.time() - t3
|
||||||
|
|
||||||
|
@ -128,7 +126,7 @@ def train_model(model,
|
||||||
return train_loss_log, val_loss_log
|
return train_loss_log, val_loss_log
|
||||||
|
|
||||||
|
|
||||||
def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False):
|
def validate(model, val_loader, loss_fn, device='cuda', target_crop:int = None):
|
||||||
model.eval()
|
model.eval()
|
||||||
losses = []
|
losses = []
|
||||||
styles = []
|
styles = []
|
||||||
|
@ -141,10 +139,8 @@ def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False):
|
||||||
target = batch['target'].to(device)
|
target = batch['target'].to(device)
|
||||||
style = batch['style'].to(device)
|
style = batch['style'].to(device)
|
||||||
|
|
||||||
if shrink:
|
if target_crop:
|
||||||
# Shrink the target to the center region of size N//2
|
target = target[..., target_crop:-target_crop, target_crop:-target_crop, target_crop:-target_crop]
|
||||||
N = target.shape[-1]
|
|
||||||
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
|
|
||||||
|
|
||||||
output = model(input, style)
|
output = model(input, style)
|
||||||
loss = loss_fn(output, target)
|
loss = loss_fn(output, target)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue