bugfix rank momenta

This commit is contained in:
Mayeul Aubin 2025-07-04 10:58:42 +02:00
parent 75a3e5151e
commit 6d876485f7

View file

@ -217,7 +217,6 @@ class MomentaDataset(Dataset):
for file, varname in zip(input_paths, self.initial_conditions_variables) for file, varname in zip(input_paths, self.initial_conditions_variables)
] ]
target_array = read_field_chunk_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array target_array = read_field_chunk_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array
target_array = np.moveaxis(target_array, -1, 0) # Move channel dimension to the front
# Stack the input arrays # Stack the input arrays
input_tensor = np.stack(input_arrays, axis=0) input_tensor = np.stack(input_arrays, axis=0)