improvements

parent 6c526d7115
commit 58f9b27e6e

3 changed files with 153 additions and 7 deletions
@@ -186,16 +186,18 @@ class GravPotDataset(Dataset):
             'style': style_path
         }


-    def __getitem__(self, idx):
+    def get_data(self, ID, t, ox, oy, oz):
+        """
+        Get the data for a specific ID, time, and offsets.
+        Returns a dictionary with input, target, and style tensors.
+        """
         from pysbmy.field import read_field_chunk_3D_periodic
         from io import BytesIO
         import torch
         from sbmy_control.low_level import stdout_redirector, stderr_redirector
         f = BytesIO()

-        ID, t, ox, oy, oz = self.samples[idx]
-
         # Filepaths
         input_paths = [
             os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
@@ -236,6 +238,15 @@ class GravPotDataset(Dataset):
             'time': t,
             'offset': (ox, oy, oz)
         }
+
+
+    def __getitem__(self, idx):
+
+        ID, t, ox, oy, oz = self.samples[idx]
+        return self.get_data(ID, t, ox, oy, oz)
+
+

     def on_epoch_end(self):
         """Call this at the end of each epoch to regenerate offset + time choices."""
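For reference, after this split a sample can be fetched either by index or directly by its (ID, time, offset) key. A minimal sketch, assuming an already-constructed dataset instance (its constructor arguments are not shown in this commit):

# Hypothetical usage; `dataset` is an assumed GravPotDataset instance with populated samples.
ID, t, ox, oy, oz = dataset.samples[0]
sample = dataset.get_data(ID, t, ox, oy, oz)     # same dict that dataset[0] would return
assert {'input', 'target', 'style', 'time', 'offset'} <= set(sample)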
@@ -56,7 +56,9 @@ class UNet3D(BaseModel):
                  in_channels: int = 2,
                  out_channels: int = 1,
                  style_dim: int = 2,
-                 device: torch.device = torch.device('cpu')):
+                 device: torch.device = torch.device('cpu'),
+                 first_layer_channel_exponent: int = 3,
+                 ):
         """
         3D U-Net model with optional FiLM layers for style conditioning.
         Parameters:
@ -78,7 +80,7 @@ class UNet3D(BaseModel):
|
|||
import numpy as np
|
||||
|
||||
self.depth = np.floor(np.log2(N)).astype(int) - 1 # Depth of the U-Net based on input size N
|
||||
self.first_layer_channel_exponent = 3
|
||||
self.first_layer_channel_exponent = first_layer_channel_exponent
|
||||
|
||||
self.enc=[]
|
||||
|
||||
|
|
|
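For context, the new constructor argument only replaces the previously hard-coded exponent; the depth formula is unchanged. A small sketch of the arithmetic, where reading the exponent as the first-level channel count is an assumption based on its name:

import numpy as np

N = 128                                          # example input grid size (assumption)
depth = np.floor(np.log2(N)).astype(int) - 1     # -> 6, same formula as in __init__
first_layer_channels = 2 ** 4                    # with first_layer_channel_exponent=4, presumably 16 channels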
@@ -33,7 +33,7 @@ def train_model(model,
     if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
     if scheduler is None:
-        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//4)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
     model.to(device)
     loss_fn = torch.nn.MSELoss()
     train_loss_log = []
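For illustration, the new default patience written out on its own; the placeholder optimizer below is not part of the commit:

import torch

# num_epochs = 50 now gives patience = 50 // 5 = 10 epochs without improvement
# before the learning rate is reduced (previously 50 // 4 = 12).
params = [torch.nn.Parameter(torch.zeros(1))]    # placeholder parameters
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50 // 5)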
@@ -148,3 +148,136 @@ def validate(model, val_loader, loss_fn, device='cuda'):
     bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]

     return losses.mean(), bin_means, bins
+
+
+def train_models(models,
+                 dataloader,
+                 optimizers=None,
+                 num_epochs=10,
+                 device='cuda',
+                 print_timers=False,
+                 save_model_paths=None,
+                 schedulers=None):
+    """
+    Train multiple models with their respective optimizers on a shared dataloader.
+    This is useful since the main bottleneck is I/O, so several models can be trained on the same loaded data.
+
+    Parameters:
+    - models: List of models to train.
+    - dataloader: Dictionary with 'train' and 'val' DataLoader objects.
+    - optimizers: List of optimizers for each model (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the models (default is 10).
+    - device: Device to run the models on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_paths: List of paths to save the models after each epoch.
+    - schedulers: List of learning rate schedulers for each model (optional).
+
+    Returns:
+    - train_loss_logs: List of training losses for each model.
+    - val_loss_logs: List of validation losses for each model.
+    """
+
+    if optimizers is None:
+        optimizers = [torch.optim.Adam(model.parameters(), lr=1e-4) for model in models]
+
+    if schedulers is None:
+        schedulers = [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//5)
+                      for optimizer in optimizers]
+
+    models = [model.to(device) for model in models]
+
+    loss_fns = [torch.nn.MSELoss() for _ in models]
+
+    train_loss_logs = [[] for _ in models]
+    val_loss_logs = [[] for _ in models]
+
+    if save_model_paths is None:
+        save_model_paths = [None] * len(models)
+
+    if len(save_model_paths) != len(models) or len(optimizers) != len(models) or len(schedulers) != len(models):
+        raise ValueError("Length of save_model_paths, optimizers, and schedulers must match the number of models.")
+
+    print(f"Starting training for {len(models)} models...")
+
+    for epoch in range(num_epochs):
+
+        for model in models:
+            model.train()
+
+        progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
+        io_time = 0.0
+        forward_time = 0.0
+        backward_time = 0.0
+        validation_time = 0.0
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time
+
+        for batch in progress_bar:
+            # I/O timer: time since last batch processed
+            t0 = time.time()
+            io_time += t0 - prev_time
+
+            batch = prepare_data(batch)
+            input = batch['input'].to(device)
+            target = batch['target'].to(device)
+            style = batch['style'].to(device)
+
+            # Loop on models for training
+            for i, model in enumerate(models):
+                optimizers[i].zero_grad()
+
+                # Forward pass
+                t1 = time.time()
+                output = model(input, style)
+                loss = loss_fns[i](output, target)
+                forward_time += time.time() - t1
+
+                # Backward pass and optimization
+                t2 = time.time()
+                loss.backward()
+                optimizers[i].step()
+                backward_time += time.time() - t2
+
+                train_loss_logs[i].append(loss.item())
+                progress_bar.set_postfix(loss=f"{loss.item():2.5f}")
+
+            prev_time = time.time()
+
+        # End of epoch, validate the models
+        t3 = time.time()
+        for i, model in enumerate(models):
+            val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fns[i], device)
+            val_loss_logs[i].append(val_loss)
+
+            if save_model_paths[i] is not None:
+                torch.save(model.state_dict(), save_model_paths[i] + f"_epoch_{epoch+1}.pth")
+                torch.save(dict(train_loss_log=train_loss_logs[i],
+                                val_loss_log=val_loss_logs[i],
+                                style_bins_means=style_bins_means,
+                                style_bins=style_bins),
+                           save_model_paths[i] + f"_epoch_{epoch+1}_stats.pth")
+
+            if schedulers[i] is not None:
+                schedulers[i].step(val_loss)
+        validation_time += time.time() - t3
+
+        # Prepare new samples for the next epoch
+        dataloader['train'].dataset.on_epoch_end()
+        dataloader['val'].dataset.on_epoch_end()
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        for i, model in enumerate(models):
+            print(f"Model {i+1} Validation Loss: {val_loss_logs[i][-1]:2.6f}")
+
+        if print_timers:
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"  I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"  Forward time:     {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"  Backward time:    {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"  Validation time:  {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+            print()
+
+    print("Training complete.")
+    return train_loss_logs, val_loss_logs
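A hedged usage sketch for the new train_models entry point; the model instances, loaders, and save paths below are placeholders, not part of the commit:

# Hypothetical call; model_a, model_b, train_loader and val_loader are assumed to exist.
models = [model_a, model_b]                              # e.g. two UNet3D variants
loaders = {'train': train_loader, 'val': val_loader}     # dict layout described in the docstring
train_logs, val_logs = train_models(models,
                                    loaders,
                                    num_epochs=20,
                                    device='cuda',
                                    print_timers=True,
                                    save_model_paths=['runs/model_a', 'runs/model_b'])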