ML_GravPotBCs/sCOCA_ML/prepare_data/prepare_momenta_data.py
2025-06-26 11:31:29 +02:00

29 lines
No EOL
788 B
Python

from .utils_functions import *
def prepare_data(batch,
                 scale_phi_ini: float = 1000.0,
                 scale_delta_ini: float = 12.0,
                 scale_target: float = 50.0,
                 lin_threshold_target: float = 2.,
                 ):
    """Return a rescaled view of ``batch`` for the momenta model.

    The caller's ``batch`` is left untouched: the input array is copied
    before its channels are rescaled, so calling this function more than
    once on the same batch always yields the same result.

    Args:
        batch: Mapping with the keys ``'input'``, ``'target'`` and
            ``'style'``. ``'input'`` and ``'target'`` are indexed with five
            indices (batch, channel, and three trailing axes); ``'style'``
            is indexed as a 2-D (batch, feature) array.
        scale_phi_ini: Divisor applied to channel 1 of ``batch['input']``.
        scale_delta_ini: Divisor applied to channel 0 of ``batch['input']``.
        scale_target: Currently unused; kept so existing callers that pass
            it keep working.
        lin_threshold_target: Currently unused; kept for the same reason.

    Returns:
        A dict with:
            ``'input'``: copy of ``batch['input']`` with channels 0 and 1
                divided by ``scale_delta_ini`` and ``scale_phi_ini``.
            ``'target'``: channel 0 of ``batch['target']`` (channel axis
                kept) divided by ``(1.14 * a) ** 2.31``.
            ``'style'``: ``batch['style']`` itself, unmodified.
    """
    # NOTE(review): earlier commented-out code suggests input channel 0 is
    # delta_ini and channel 1 is phi_ini, matching the scale names -- confirm.

    # First style feature, broadcast against the 5-D fields.
    # NOTE(review): presumably the cosmological scale factor -- confirm.
    a = batch['style'][:, [0], None, None, None]

    # Keep the channel axis (list index) so the target stays 5-D.
    momenta = batch['target'][:, [0], :, :, :]

    # Work on a copy: scaling batch['input'] in place would corrupt the
    # caller's data and compound the scaling if the batch is reused.
    # `.clone()` covers torch tensors, `.copy()` covers numpy arrays.
    source = batch['input']
    _input = source.clone() if hasattr(source, 'clone') else source.copy()
    _input[:, 0, :, :, :] /= scale_delta_ini
    _input[:, 1, :, :, :] /= scale_phi_ini

    # Normalise the momenta by a power law in `a`; the constants 1.14 and
    # 2.31 are hard-coded here and their origin is not shown in this file.
    _target = momenta / (1.14 * a) ** 2.31

    return {
        'input': _input,
        'target': _target,
        'style': batch['style'],
    }