train shrink
This commit is contained in:
parent
47cfefd61f
commit
773f9ad860
1 changed files with 16 additions and 4 deletions
|
@ -12,7 +12,8 @@ def train_model(model,
|
|||
device='cuda',
|
||||
print_timers=False,
|
||||
save_model_path=None,
|
||||
scheduler=None):
|
||||
scheduler=None,
|
||||
shrink:bool = False):
|
||||
"""
|
||||
Train a model with the given dataloader and optimizer.
|
||||
|
||||
|
@ -25,6 +26,7 @@ def train_model(model,
|
|||
- print_timers: If True, print timing information for each epoch (default is False).
|
||||
- save_model_path: If provided, the model will be saved to this path after each epoch.
|
||||
- scheduler: Learning rate scheduler (optional).
|
||||
- shrink: If True, the model will output only the center region of size N//2 of the target
|
||||
|
||||
Returns:
|
||||
- train_loss_log: List of training losses for each batch.
|
||||
|
@ -64,6 +66,11 @@ def train_model(model,
|
|||
# Forward pass
|
||||
t1 = time.time()
|
||||
output = model(input, style)
|
||||
|
||||
if shrink:
|
||||
# Shrink the target to the center region of size N//2
|
||||
N = target.shape[-1]
|
||||
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
|
||||
loss = loss_fn(output, target)
|
||||
forward_time += time.time() - t1
|
||||
|
||||
|
@ -80,7 +87,7 @@ def train_model(model,
|
|||
|
||||
# End of epoch, validate the model
|
||||
t3 = time.time()
|
||||
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
|
||||
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device, shrink=shrink)
|
||||
val_loss_log.append(val_loss)
|
||||
validation_time += time.time() - t3
|
||||
|
||||
|
@ -112,7 +119,7 @@ def train_model(model,
|
|||
if print_timers:
|
||||
total_time = time.time() - epoch_start_time
|
||||
print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
|
||||
print(f" I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
|
||||
print(f" Sync time + I/O: {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
|
||||
print(f" Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
|
||||
print(f" Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
|
||||
print(f" Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
|
||||
|
@ -121,7 +128,7 @@ def train_model(model,
|
|||
return train_loss_log, val_loss_log
|
||||
|
||||
|
||||
def validate(model, val_loader, loss_fn, device='cuda'):
|
||||
def validate(model, val_loader, loss_fn, device='cuda', shrink:bool = False):
|
||||
model.eval()
|
||||
losses = []
|
||||
styles = []
|
||||
|
@ -133,6 +140,11 @@ def validate(model, val_loader, loss_fn, device='cuda'):
|
|||
input = batch['input'].to(device)
|
||||
target = batch['target'].to(device)
|
||||
style = batch['style'].to(device)
|
||||
|
||||
if shrink:
|
||||
# Shrink the target to the center region of size N//2
|
||||
N = target.shape[-1]
|
||||
target = target[..., N//4:N*3//4, N//4:N*3//4, N//4:N*3//4]
|
||||
|
||||
output = model(input, style)
|
||||
loss = loss_fn(output, target)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue