diff --git a/sCOCA_ML/dataset/momenta_dataset.py b/sCOCA_ML/dataset/momenta_dataset.py index a36b650..328269c 100644 --- a/sCOCA_ML/dataset/momenta_dataset.py +++ b/sCOCA_ML/dataset/momenta_dataset.py @@ -217,7 +217,6 @@ class MomentaDataset(Dataset): for file, varname in zip(input_paths, self.initial_conditions_variables) ] target_array = read_field_chunk_periodic(target_path, self.N, self.N, self.N, ox, oy, oz, name=self.target_variable).array - target_array = np.moveaxis(target_array, -1, 0) # Move channel dimension to the front # Stack the input arrays input_tensor = np.stack(input_arrays, axis=0)