from .utils_functions import *


def prepare_data(batch,
                 scale_phi_ini: float = 1000.0,
                 scale_delta_ini: float = 12.0,
                 scale_target: float = 50.0,
                 lin_threshold_target: float = 2.,
                 ) -> dict:
    """Rescale a raw batch into the input/target/style dict returned below.

    The first two channels of ``batch['input']`` are divided by fixed
    scales, and channel 0 of ``batch['target']`` is divided by a power law
    of the first style entry. The caller's ``batch`` is left untouched:
    the input is copied before it is rescaled, so calling this function
    repeatedly on the same batch always yields the same result.

    Args:
        batch: Mapping with the keys ``'input'``, ``'target'`` and
            ``'style'``. ``'input'`` and ``'target'`` are indexed as 5-D
            arrays (batch, channel, and three trailing axes); ``'style'``
            is indexed as (batch, feature).
        scale_phi_ini: Divisor applied to input channel 1.
        scale_delta_ini: Divisor applied to input channel 0.
        scale_target: Currently unused; kept so existing callers that pass
            it keep working.
        lin_threshold_target: Currently unused; kept for the same reason.

    Returns:
        A dict with:
            ``'input'``: rescaled copy of ``batch['input']``.
            ``'target'``: ``batch['target'][:, [0]] / (1.14 * a) ** 2.31``
            where ``a`` is ``batch['style'][:, [0]]`` broadcast over the
            trailing axes.
            ``'style'``: ``batch['style']`` passed through unchanged (same
            object, not a copy).
    """
    # Channel layout per the original author's notes: input channel 0 is
    # delta_ini, input channel 1 is phi_ini.

    # First style entry reshaped to (batch, 1, 1, 1, 1) so it broadcasts
    # against the 5-D target.
    a = batch['style'][:, [0], None, None, None]
    momenta = batch['target'][:, [0], :, :, :]

    # Work on a private copy: the in-place divisions below would otherwise
    # rescale the caller's batch, and a second call on the same batch would
    # scale it twice. Tensor-like objects expose clone(); array-like
    # objects expose copy().
    raw_input = batch['input']
    if hasattr(raw_input, 'clone'):
        _input = raw_input.clone()
    else:
        _input = raw_input.copy()

    _input[:, 0, :, :, :] /= scale_delta_ini
    _input[:, 1, :, :, :] /= scale_phi_ini

    # NOTE(review): 1.14 and 2.31 are hard-coded here; their origin is not
    # visible in this file -- confirm before changing.
    _target = momenta / (1.14 * a) ** (2.31)

    _style = batch['style']

    return {
        'input': _input,
        'target': _target,
        'style': _style,
    }