from .utils_functions import *


def _copy_array(x):
    """Return an independent copy of the array-like ``x``.

    Uses ``x.clone()`` when the object provides it (e.g. a PyTorch tensor)
    and ``x.copy()`` otherwise (e.g. a NumPy array), so the caller can
    modify the result in place without touching ``x``.
    """
    clone = getattr(x, "clone", None)
    if callable(clone):
        return clone()
    return x.copy()


def prepare_data(batch,
                 scale_phi_ini: float = 1000.0,
                 scale_delta_ini: float = 12.0,
                 scale_target: float = 50.0,
                 lin_threshold_target: float = 2.,
                 ):
    """Rescale a batch into model-ready ``input`` / ``target`` / ``style`` arrays.

    The input batch is NOT modified: the returned ``input`` is a scaled copy
    of ``batch['input']``.

    Args:
        batch: Mapping with keys ``'input'``, ``'style'`` and ``'target'``.
            ``'input'`` and ``'target'`` are indexed as 5-D arrays with the
            channel on axis 1 (presumably ``(batch, channel, x, y, z)`` --
            confirm against the data loader). ``'style'`` is indexed as
            ``[:, channel]`` and is assumed 2-D (TODO confirm).
        scale_phi_ini: Divisor applied to input channel 1 (``phi_ini``).
        scale_delta_ini: Divisor applied to input channel 0.
        scale_target: Divisor applied to the target before the transform.
        lin_threshold_target: ``threshold`` passed to
            ``CustomSymlogTransform`` (imported from ``utils_functions``;
            presumably the half-width of its linear region -- verify there).

    Returns:
        dict with keys ``'input'`` (scaled copy of ``batch['input']``),
        ``'target'`` (transformed target built from target channel 0,
        unscaled input channel 1, and style channel 0) and ``'style'``
        (``batch['style']``, returned unchanged and uncopied).
    """
    target_transform = CustomSymlogTransform(threshold=lin_threshold_target)

    # List (advanced) indexing yields copies, so phi_ini and gravpot hold the
    # raw, unscaled values regardless of the in-place scaling of _input below.
    phi_ini = batch['input'][:, [1], :, :, :]
    # Extra `None` axes give D1 singleton spatial dims so it broadcasts
    # against the 5-D fields.
    D1 = batch['style'][:, [0], None, None, None]
    gravpot = batch['target'][:, [0], :, :, :]

    # Scale a copy: dividing batch['input'] in place would mutate the caller's
    # data and compound the scaling if the same batch were prepared twice.
    _input = _copy_array(batch['input'])
    _input[:, 0, :, :, :] /= scale_delta_ini
    _input[:, 1, :, :, :] /= scale_phi_ini

    _target = (gravpot / D1 - phi_ini) / D1
    _target /= scale_target
    _target = target_transform(_target)

    return {
        'input': _input,
        'target': _target,
        'style': batch['style'],
    }