many improvements
parent c07ec8f8cf
commit 6c526d7115
4 changed files with 219 additions and 53 deletions
@@ -36,7 +36,7 @@ class GravPotDataset(Dataset):
                  N:int=128,
                  N_full:int=768,
                  match_str:str='train',
-                 device=torch.device('cpu'),
+                 device='cpu',
                  initial_conditions_variables:tuple|list=['DM_delta', 'DM_phi'],
                  target_variable:str='gravpot',
                  style_files:str='cosmo_and_time_parameters',
@@ -50,6 +50,7 @@ class GravPotDataset(Dataset):
         - N: Size of the chunks to read (N x N x N).
+        - N_full: Full size of the simulation box (N_full x N_full x N_full).
         - device: Device to load tensors onto (default is CPU)."""
         super().__init__()

         self.initial_conditions_variables = initial_conditions_variables
         self.target_variable = target_variable
@@ -152,10 +153,44 @@ class GravPotDataset(Dataset):
     def __len__(self):
         return len(self.samples)

+    def files_from_samples(self, sample):
+        """
+        Return the paths to the files for a given sample.
+        """
+        ID, t, ox, oy, oz = sample
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+
+    def files_from_ID_and_time(self, ID, t):
+        """
+        Return the paths to the files for a given ID and time.
+        """
+        input_paths = [
+            os.path.join(self.root_dir, self.INITIAL_CONDITIONS_DIR, f'ICs_{ID}_{var}.h5')
+            for var in self.initial_conditions_variables
+        ]
+        target_path = os.path.join(self.root_dir, self.TARGET_DIR, f'{self.target_variable}_{ID}_nforce{t}.h5')
+        style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')
+        return {
+            'input': input_paths,
+            'target': target_path,
+            'style': style_path
+        }
+

     def __getitem__(self, idx):
         from pysbmy.field import read_field_chunk_3D_periodic
         from io import BytesIO
         import torch
+        from sbmy_control.low_level import stdout_redirector, stderr_redirector
+        f = BytesIO()
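Note: the two new helpers make the on-disk layout easy to inspect without going through __getitem__. A minimal usage sketch, assuming the dataset constructor takes a root_dir argument (the ID and time step below are placeholders):

    ds = GravPotDataset(root_dir='/path/to/data')        # root_dir assumed; other args as in the signature above
    paths = ds.files_from_ID_and_time(ID='0001', t=10)   # hypothetical ID and time step
    # paths['input']  -> [.../ICs_0001_DM_delta.h5, .../ICs_0001_DM_phi.h5]
    # paths['target'] -> .../gravpot_0001_nforce10.h5
    # paths['style']  -> .../cosmo_and_time_parameters_0001_nforce10.txt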
@@ -170,12 +205,11 @@ class GravPotDataset(Dataset):
         style_path = os.path.join(self.root_dir, self.STYLE_DIR, f'{self.style_files}_{ID}_nforce{t}.txt')

         # Read 3D chunks
-        input_arrays = [
-            read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
-            for file, varname in zip(input_paths, self.initial_conditions_variables)
-        ]
-        target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
+        with stdout_redirector(f):
+            input_arrays = [
+                read_field_chunk_3D_periodic(file, self.N,self.N,self.N, ox,oy,oz, name=varname).array
+                for file, varname in zip(input_paths, self.initial_conditions_variables)
+            ]
+            target_array = read_field_chunk_3D_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array

         # Stack the input arrays
         input_tensor = np.stack(input_arrays, axis=0)
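Note: the with stdout_redirector(f): block silences pysbmy's console output during the chunk reads. The real helper lives in sbmy_control.low_level and is not shown in this diff; purely as an illustration of the pattern, a Python-level equivalent might look like the sketch below (the actual implementation presumably also captures file-descriptor output from compiled extensions, which this sketch does not):

    import sys
    from contextlib import contextmanager

    @contextmanager
    def stdout_redirector_sketch(stream):
        # Swap sys.stdout for `stream` (a text stream such as io.StringIO) inside the block.
        old_stdout = sys.stdout
        sys.stdout = stream
        try:
            yield
        finally:
            sys.stdout = old_stdout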
@@ -206,21 +240,45 @@ class GravPotDataset(Dataset):
     def on_epoch_end(self):
         """Call this at the end of each epoch to regenerate offset + time choices."""
         self._prepare_samples()


 class SubDataset(Dataset):
-    def __init__(self, dataset: GravPotDataset, indices: list):
-        self.dataset = dataset
-        self.indices = indices
+    def __init__(self, dataset: GravPotDataset, ID_list: list):
+        from copy import deepcopy
+        self.dataset = deepcopy(dataset)
+        self.ids = ID_list
+        self.dataset.ids = ID_list

     def __len__(self):
-        return len(self.indices)
+        return len(self.dataset)

     def __getitem__(self, idx):
-        return self.dataset[self.indices[idx]]
+        return self.dataset[idx]

+    def on_epoch_end(self):
+        self.dataset.ids = self.ids
+        self.dataset.on_epoch_end()
+
+
+def train_val_split(dataset: GravPotDataset, val_fraction: float = 0.2, seed: int = 42):
+    """
+    Splits the dataset into training and validation sets.
+
+    Parameters:
+    - dataset: The GravPotDataset to split.
+    - val_fraction: Fraction of the dataset to use for validation.
+    - seed: Random seed for the split.
+
+    Returns:
+    - train_dataset: SubDataset for training.
+    - val_dataset: SubDataset for validation.
+    """
+    from sklearn.model_selection import train_test_split
+    train_ids, val_ids = train_test_split(dataset.ids, test_size=val_fraction, random_state=seed)
+    train_dataset = SubDataset(dataset, train_ids)
+    val_dataset = SubDataset(dataset, val_ids)
+    train_dataset.dataset._prepare_samples()
+    val_dataset.dataset._prepare_samples()
+
+    return train_dataset, val_dataset
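Note: a minimal wiring sketch for the new split helper; the batch size, worker count, and the root_dir argument are illustrative, not from the repo:

    from torch.utils.data import DataLoader

    full_ds = GravPotDataset(root_dir='/path/to/data', N=128)   # root_dir assumed
    train_ds, val_ds = train_val_split(full_ds, val_fraction=0.2, seed=42)
    dataloader = {
        'train': DataLoader(train_ds, batch_size=4, shuffle=True,  num_workers=4),
        'val':   DataLoader(val_ds,   batch_size=4, shuffle=False, num_workers=4),
    }
    # train_model below expects exactly this {'train': ..., 'val': ...} dictionary.

Since SubDataset now deep-copies the parent dataset, the train and val copies hold independent sample lists, and on_epoch_end re-restricts each copy to its own ID list before regenerating samples.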
@@ -30,6 +30,27 @@ class UNetBlock(nn.Module):
         x = self.film(x, style)
         return x

+class UNetEncLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, style_dim=None):
+        super(UNetEncLayer, self).__init__()
+        self.block = UNetBlock(in_channels, out_channels, style_dim)
+        self.pool = nn.MaxPool3d(2)
+
+    def forward(self, x, style=None):
+        x = self.block(x, style)
+        return x, self.pool(x)
+
+class UNetDecLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, skip_connection_channels, style_dim=None):
+        super(UNetDecLayer, self).__init__()
+        self.up = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
+        self.block = UNetBlock(out_channels + skip_connection_channels, out_channels, style_dim)
+
+    def forward(self, x, skip_connection, style=None):
+        x = self.up(x)
+        x = torch.cat([x, skip_connection], dim=1)
+        return self.block(x, style)
+
 class UNet3D(BaseModel):
     def __init__(self, N: int = 128,
                  in_channels: int = 2,
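Note: UNetEncLayer returns the pre-pool activation (the skip connection) alongside the pooled tensor, and UNetDecLayer consumes one skip per stage. A quick shape check, assuming UNetBlock preserves spatial size and tolerates style=None (channel counts are illustrative):

    import torch

    enc = UNetEncLayer(in_channels=2, out_channels=8)
    skip, pooled = enc(torch.randn(1, 2, 32, 32, 32))
    # skip:   (1, 8, 32, 32, 32)   kept for the decoder
    # pooled: (1, 8, 16, 16, 16)   passed down the encoder

    dec = UNetDecLayer(in_channels=16, out_channels=8, skip_connection_channels=8)
    out = dec(torch.randn(1, 16, 16, 16, 16), skip)
    # out: (1, 8, 32, 32, 32)   after upsampling and concatenating the skip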
@@ -54,23 +75,51 @@ class UNet3D(BaseModel):
                          out_channels=out_channels,
                          style_parameters=style_dim,
                          device=device)
+        import numpy as np

-        self.enc1 = UNetBlock(in_channels, 32, style_dim)
-        self.pool1 = nn.MaxPool3d(2)
-        self.enc2 = UNetBlock(32, 64, style_dim)
-        self.pool2 = nn.MaxPool3d(2)
-        self.bottleneck = UNetBlock(64, 128, style_dim)
-
-        self.up2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
-        self.dec2 = UNetBlock(128, 64)
-        self.up1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
-        self.dec1 = UNetBlock(64, 32)
-        self.final = nn.Conv3d(32, out_channels, kernel_size=1)
+        self.depth = np.floor(np.log2(N)).astype(int) - 1  # Depth of the U-Net based on input size N
+        self.first_layer_channel_exponent = 3
+
+        self.enc = []
+
+        for i in range(self.depth):
+            in_ch = in_channels if i == 0 else 2**(self.first_layer_channel_exponent + i - 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            self.enc.append(UNetEncLayer(in_ch, out_ch, style_dim))
+
+        self.enc = nn.ModuleList(self.enc)
+
+        self.bottleneck = UNetBlock(2**(self.first_layer_channel_exponent + self.depth - 1),
+                                    2**(self.first_layer_channel_exponent + self.depth), style_dim)
+
+        self.dec = []
+
+        for i in range(self.depth - 1, -1, -1):
+            in_ch = 2**(self.first_layer_channel_exponent + i + 1)
+            out_ch = 2**(self.first_layer_channel_exponent + i)
+            skip_conn_ch = out_ch
+            self.dec.append(UNetDecLayer(in_ch, out_ch, skip_conn_ch, style_dim))
+
+        self.dec = nn.ModuleList(self.dec)
+
+        self.final = nn.Conv3d(2**(self.first_layer_channel_exponent), out_channels, kernel_size=1)

     def forward(self, x, style):
-        e1 = self.enc1(x, style)
-        e2 = self.enc2(self.pool1(e1), style)
-        b = self.bottleneck(self.pool2(e2), style)
-        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
-        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
-        return self.final(d1)
+        out = x
+        outlist = []
+
+        for i in range(self.depth):
+            skip, out = self.enc[i](out, style)
+            outlist.append(skip)
+
+        out = self.bottleneck(out, style)
+
+        for i in range(self.depth):
+            out = self.dec[i](out, outlist[self.depth - 1 - i], style)
+
+        return self.final(out)
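Note: with the defaults, N = 128 gives depth = floor(log2 128) - 1 = 6, so the encoder channels run 8, 16, 32, 64, 128, 256 (exponents 3 through 8) with a 512-channel bottleneck, and final maps 8 channels to out_channels. A sanity-check sketch, assuming the remaining constructor arguments (style_dim, device) keep workable defaults:

    import torch

    net = UNet3D(N=128, in_channels=2, out_channels=1)   # other constructor args assumed to default
    x = torch.randn(1, 2, 128, 128, 128)
    style = torch.randn(1, 2)                            # placeholder; size must match style_dim
    print(net(x, style).shape)                           # expected: torch.Size([1, 1, 128, 128, 128])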
@@ -1,12 +1,23 @@
-def prepare_data(batch):
+def prepare_data(batch,
+                 scale_phi_ini:float = 1000.0,
+                 scale_delta_ini:float = 12.0,
+                 scale_target:float = 600.0,
+                 ):

-    phi_ini = batch['input'][:, [1]]
-    D1 = batch['style'][:, [0]]
-    D2 = batch['style'][:, [1]]
-    gravpot = batch['target'][:, [0]]
+    # delta_ini = batch['input'][:, [0], :, :, :]
+    phi_ini = batch['input'][:, [1], :, :, :]
+    D1 = batch['style'][:, [0], None, None, None]
+    # D2 = batch['style'][:, [1], None, None, None]
+    gravpot = batch['target'][:, [0], :, :, :]

+    _input = batch['input']
+    _input[:, 0, :, :, :] /= scale_delta_ini
+    _input[:, 1, :, :, :] /= scale_phi_ini

     _target = (gravpot/D1 - phi_ini)/D1
+    _target /= scale_target

     _style = batch['style']

     return {
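Note: the normalized target is the residual potential, _target = ((gravpot/D1) - phi_ini)/D1 scaled by 1/scale_target, so gravpot = D1*(phi_ini + D1*scale_target*_target). A hedged inversion sketch for mapping network predictions back to the raw potential (the function name is mine, not from the repo; phi_ini must be the unscaled initial potential, as sliced before the in-place /= above):

    def unprepare_target(pred, phi_ini, D1, scale_target=600.0):
        # Invert prepare_data's target transform; pred is the network output in normalized units.
        return D1 * (phi_ini + D1 * scale_target * pred)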
@@ -5,12 +5,39 @@ import torch
 import time
 from ..prepare_data.prepare_gravpot_data import prepare_data

-def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda', print_timers=False):
+def train_model(model,
+                dataloader,
+                optimizer=None,
+                num_epochs=10,
+                device='cuda',
+                print_timers=False,
+                save_model_path=None,
+                scheduler=None):
+    """
+    Train a model with the given dataloader and optimizer.
+
+    Parameters:
+    - model: The model to train.
+    - dataloader: A dictionary with 'train' and 'val' DataLoader objects.
+    - optimizer: The optimizer to use for training (default is Adam with lr=1e-4).
+    - num_epochs: Number of epochs to train the model (default is 10).
+    - device: Device to run the model on (default is 'cuda').
+    - print_timers: If True, print timing information for each epoch (default is False).
+    - save_model_path: If provided, the model will be saved to this path after each epoch.
+    - scheduler: Learning rate scheduler (default is ReduceLROnPlateau on the validation loss).
+
+    Returns:
+    - train_loss_log: List of training losses for each batch.
+    - val_loss_log: List of validation losses for each epoch."""
+
     if optimizer is None:
-        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+    if scheduler is None:
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=num_epochs//4)
     model.to(device)
     loss_fn = torch.nn.MSELoss()
-    loss_log = []
+    train_loss_log = []
+    val_loss_log = []

     for epoch in range(num_epochs):
         model.train()
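Note: a hypothetical end-to-end call with the new arguments, reusing the dataloader dict sketched earlier (the checkpoint path is a placeholder):

    model = UNet3D(N=128, in_channels=2, out_channels=1)
    train_loss_log, val_loss_log = train_model(model,
                                               dataloader,
                                               num_epochs=20,
                                               print_timers=True,
                                               save_model_path='checkpoints/unet3d')
    # Writes checkpoints/unet3d_epoch_1.pth, checkpoints/unet3d_epoch_1_stats.pth, ... each epoch.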
@@ -20,7 +47,8 @@ def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda',
         backward_time = 0.0
         validation_time = 0.0

-        prev_time = time.time()
+        epoch_start_time = time.time()
+        prev_time = epoch_start_time  # For I/O timing
         for batch in progress_bar:
             # I/O timer: time since last batch processed
             t0 = time.time()
@@ -45,32 +73,52 @@ def train_model(model, dataloader, optimizer=None, num_epochs=10, device='cuda',
             optimizer.step()
             backward_time += time.time() - t2

-            loss_log.append((style[:, 0].detach().cpu().numpy(), loss.item()))
-            progress_bar.set_postfix(loss=loss.item())
+            train_loss_log.append(loss.item())
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")

             prev_time = time.time()  # End of loop, for next I/O timing

         # End of epoch, validate the model
         t3 = time.time()
         val_loss, style_bins_means, style_bins = validate(model, dataloader['val'], loss_fn, device)
+        val_loss_log.append(val_loss)
         validation_time += time.time() - t3

-        print(f"Validation Loss: {val_loss:.4f}")
         # Prepare new samples for the next epoch
         dataloader['train'].dataset.on_epoch_end()
         dataloader['val'].dataset.on_epoch_end()

+        if save_model_path is not None:
+            torch.save(model.state_dict(), save_model_path + f"_epoch_{epoch+1}.pth")
+            torch.save(dict(train_loss_log=train_loss_log,
+                            val_loss_log=val_loss_log,
+                            style_bins_means=style_bins_means,
+                            style_bins=style_bins),
+                       save_model_path + f"_epoch_{epoch+1}_stats.pth")
+
+        if scheduler is not None:
+            scheduler.step(val_loss)
+
+        print()
+        print(f"================ Epoch {epoch+1} Summary ================")
+        print(f"Validation Loss: {val_loss:2.6f}")
         bin_width = max([len(f"{m:.2f}") for m in style_bins_means[:-1] + [2]])  # +[2] to avoid empty
-        bins_str = "Style Bins: " + " | ".join([f"{b:>{bin_width}.2f}" for b in style_bins[:-1]])
-        means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2f}" for m in style_bins_means])
+        bins_str = "Style Bins: " + " | ".join([f" {b:>{bin_width}.2f} " for b in style_bins[:-1]])
+        means_str = "Means: " + " | ".join([f"{m:>{bin_width}.2e}" for m in style_bins_means])
         print(bins_str)
         print(means_str)
         print()

         if print_timers:
-            total_time = io_time + forward_time + backward_time + validation_time
-            print(f"Epoch {epoch+1} Timings:")
-            print(f"  I/O time: {io_time:.3f} s\t | {100 * io_time / total_time:.2f}%")
-            print(f"  Forward time: {forward_time:.3f} s\t | {100 * forward_time / total_time:.2f}%")
-            print(f"  Backward time: {backward_time:.3f} s\t | {100 * backward_time / total_time:.2f}%")
-            print(f"  Validation time: {validation_time:.3f} s\t | {100 * validation_time / total_time:.2f}%")
+            total_time = time.time() - epoch_start_time
+            print(f"Epoch {epoch+1} Timings: {total_time:9.0f} s")
+            print(f"  I/O time (train): {io_time:8.0f} s\t | {100 * io_time / total_time:2.2f}%")
+            print(f"  Forward time: {forward_time:8.0f} s\t | {100 * forward_time / total_time:2.2f}%")
+            print(f"  Backward time: {backward_time:8.0f} s\t | {100 * backward_time / total_time:2.2f}%")
+            print(f"  Validation time: {validation_time:8.0f} s\t | {100 * validation_time / total_time:2.2f}%")
+            print()

-    return loss_log
+    return train_loss_log, val_loss_log


 def validate(model, val_loader, loss_fn, device='cuda'):
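Note: the per-epoch stats file is a plain dict saved with torch.save, so it can be reloaded for plotting; a sketch with a placeholder path:

    import torch

    stats = torch.load('checkpoints/unet3d_epoch_20_stats.pth')
    train_loss_log = stats['train_loss_log']   # one entry per training batch
    val_loss_log = stats['val_loss_log']       # one entry per epoch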
@@ -89,11 +137,11 @@ def validate(model, val_loader, loss_fn, device='cuda'):
             output = model(input, style)
             loss = loss_fn(output, target)
             losses.append(loss.item())
-            styles.append(style[:, 0].cpu().numpy())
-            progress_bar.set_postfix(loss=loss.item())
+            styles.append(style[:, 0].cpu().numpy().mean())  # BEWARE: if batch size > 1, this will average the styles and make no sense
+            progress_bar.set_postfix(loss=f"{loss.item():2.5f}")

     # Bin loss by style[0]
-    styles = np.concatenate(styles)
+    styles = np.array(styles)
     losses = np.array(losses)
     bins = np.linspace(styles.min(), styles.max(), 10)
     digitized = np.digitize(styles, bins)
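Note: validate presumably finishes by averaging the losses that fall into each style bin before returning (val_loss, style_bins_means, style_bins). A sketch of that reduction under that assumption (the repo's exact handling of empty bins may differ):

    import numpy as np

    # digitized[i] in {1..len(bins)} is the bin index of styles[i]
    style_bins_means = np.array([
        losses[digitized == b].mean() if np.any(digitized == b) else np.nan
        for b in range(1, len(bins) + 1)
    ])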