expansion
This commit is contained in:
parent
24c2d546db
commit
c07ec8f8cf
6 changed files with 138 additions and 1 deletions
|
@ -206,3 +206,21 @@ class GravPotDataset(Dataset):
|
||||||
def on_epoch_end(self):
|
def on_epoch_end(self):
|
||||||
"""Call this at the end of each epoch to regenerate offset + time choices."""
|
"""Call this at the end of each epoch to regenerate offset + time choices."""
|
||||||
self._prepare_samples()
|
self._prepare_samples()
|
||||||
|
|
||||||
|
|
||||||
|
class SubDataset(Dataset):
|
||||||
|
def __init__(self, dataset: GravPotDataset, indices: list):
|
||||||
|
self.dataset = dataset
|
||||||
|
self.indices = indices
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.indices)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.dataset[self.indices[idx]]
|
||||||
|
|
||||||
|
def on_epoch_end(self):
|
||||||
|
self.dataset.on_epoch_end()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ class UNet3D(BaseModel):
|
||||||
The FiLM layers are used to condition the feature maps on style parameters.
|
The FiLM layers are used to condition the feature maps on style parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().init(N=N,
|
super().__init__(N=N,
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
style_parameters=style_dim,
|
style_parameters=style_dim,
|
||||||
|
|
0
sCOCA_ML/prepare_data/__init__.py
Normal file
0
sCOCA_ML/prepare_data/__init__.py
Normal file
17
sCOCA_ML/prepare_data/prepare_gravpot_data.py
Normal file
17
sCOCA_ML/prepare_data/prepare_gravpot_data.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
def prepare_data(batch):
|
||||||
|
|
||||||
|
phi_ini = batch['input'][:, [1]]
|
||||||
|
D1 = batch['style'][:, [0]]
|
||||||
|
D2 = batch['style'][:, [1]]
|
||||||
|
gravpot = batch['target'][:, [0]]
|
||||||
|
|
||||||
|
_input = batch['input']
|
||||||
|
_target = (gravpot/D1 - phi_ini)/D1
|
||||||
|
_style = batch['style']
|
||||||
|
|
||||||
|
return {
|
||||||
|
'input': _input,
|
||||||
|
'target': _target,
|
||||||
|
'style': _style
|
||||||
|
}
|
||||||
|
|
0
sCOCA_ML/train/__init__.py
Normal file
0
sCOCA_ML/train/__init__.py
Normal file
102
sCOCA_ML/train/train_gravpot.py
Normal file
102
sCOCA_ML/train/train_gravpot.py
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
from ..prepare_data.prepare_gravpot_data import prepare_data
|
||||||
|
|
||||||
|
def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
|
||||||
|
if optimizer is None:
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||||
|
model.to(device)
|
||||||
|
loss_fn = torch.nn.MSELoss()
|
||||||
|
loss_log = []
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
model.train()
|
||||||
|
progress_bar = tqdm(dataloader['train'], desc=f"Epoch {epoch+1}/{num_epochs}", unit='batch')
|
||||||
|
io_time = 0.0
|
||||||
|
forward_time = 0.0
|
||||||
|
backward_time = 0.0
|
||||||
|
validation_time = 0.0
|
||||||
|
|
||||||
|
prev_time = time.time()
|
||||||
|
for batch in progress_bar:
|
||||||
|
# I/O timer: time since last batch processed
|
||||||
|
t0 = time.time()
|
||||||
|
io_time += t0 - prev_time
|
||||||
|
|
||||||
|
batch = prepare_data(batch)
|
||||||
|
input = batch['input'].to(device)
|
||||||
|
target = batch['target'].to(device)
|
||||||
|
style = batch['style'].to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
t1 = time.time()
|
||||||
|
output = model(input, style)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
forward_time += time.time() - t1
|
||||||
|
|
||||||
|
# Backward pass and optimization
|
||||||
|
t2 = time.time()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
backward_time += time.time() - t2
|
||||||
|
|
||||||
|
loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
|
||||||
|
progress_bar.set_postfix(loss=loss.item())
|
||||||
|
|
||||||
|
prev_time = time.time() # End of loop, for next I/O timing
|
||||||
|
|
||||||
|
# End of epoch, validate the model
|
||||||
|
t3 = time.time()
|
||||||
|
val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
|
||||||
|
validation_time += time.time() - t3
|
||||||
|
|
||||||
|
print(f"Validation Loss: {val_loss:.4f}")
|
||||||
|
bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]]) # +[2] to avoid empty
|
||||||
|
bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
|
||||||
|
means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
|
||||||
|
print(bins_str)
|
||||||
|
print(means_str)
|
||||||
|
|
||||||
|
if print_timers:
|
||||||
|
total_time = io_time + forward_time + backward_time + validation_time
|
||||||
|
print(f"Epoch {epoch+1} Timings:")
|
||||||
|
print(f" I/O time: {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
|
||||||
|
print(f" Forward time: {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
|
||||||
|
print(f" Backward time: {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
|
||||||
|
print(f" Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
|
||||||
|
|
||||||
|
return loss_log
|
||||||
|
|
||||||
|
|
||||||
|
def validate(model, val_loader, loss_fn, device='cuda'):
|
||||||
|
model.eval()
|
||||||
|
losses = []
|
||||||
|
styles = []
|
||||||
|
progress_bar = tqdm(val_loader, desc="Validation", unit='batch')
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in progress_bar:
|
||||||
|
batch = prepare_data(batch)
|
||||||
|
input = batch['input'].to(device)
|
||||||
|
target = batch['target'].to(device)
|
||||||
|
style = batch['style'].to(device)
|
||||||
|
|
||||||
|
output = model(input, style)
|
||||||
|
loss = loss_fn(output, target)
|
||||||
|
losses.append(loss.item())
|
||||||
|
styles.append(style[:, 0].cpu().numpy())
|
||||||
|
progress_bar.set_postfix(loss=loss.item())
|
||||||
|
|
||||||
|
# Bin loss by style[0]
|
||||||
|
styles = np.concatenate(styles)
|
||||||
|
losses = np.array(losses)
|
||||||
|
bins = np.linspace(styles.min(), styles.max(), 10)
|
||||||
|
digitized = np.digitize(styles, bins)
|
||||||
|
bin_means = [losses[digitized == i].mean() if np.any(digitized == i) else 0 for i in range(1, len(bins))]
|
||||||
|
|
||||||
|
return losses.mean(), bin_means, bins
|
Loading…
Add table
Add a link
Reference in a new issue