def prepare_data(batch):
    """Return a sample dict whose target has been rescaled by a style column.

    The new target is computed as::

        (batch['target'][:, [0]] / D1 - batch['input'][:, [1]]) / D1

    with ``D1 = batch['style'][:, [0]]``. The ``'input'`` and ``'style'``
    entries are passed through as the very same objects; ``batch`` itself
    is not modified.

    Args:
        batch: Mapping with the keys ``'input'``, ``'target'`` and
            ``'style'``. Each value must support ``value[:, [k]]``
            indexing and element-wise arithmetic (presumably NumPy arrays
            or torch tensors -- confirm against the caller).

    Returns:
        A new dict with the keys ``'input'``, ``'target'`` and ``'style'``.
    """
    # Indexing with a one-element list (``[:, [k]]``) rather than a bare
    # integer keeps the second axis for NumPy/torch-style arrays, so the
    # three operands below stay broadcast-compatible.
    # NOTE(review): the names suggest an initial potential (input channel 1),
    # a gravitational potential (target channel 0) and a growth-factor-like
    # scale (style column 0) -- confirm with the data pipeline.
    phi_ini = batch['input'][:, [1]]
    scale = batch['style'][:, [0]]
    gravpot = batch['target'][:, [0]]

    # Only style column 0 enters the target; other style columns are not read.
    # No guard against zero in ``scale``: division behaves as the array
    # library defines it.
    rescaled_target = (gravpot / scale - phi_ini) / scale

    return {
        'input': batch['input'],
        'target': rescaled_target,
        'style': batch['style'],
    }