symlog transform of the target

This commit is contained in:
Mayeul Aubin 2025-06-24 09:26:37 +02:00
parent 58f9b27e6e
commit 1abb0f9fb6
2 changed files with 50 additions and 3 deletions

View file

@ -1,9 +1,14 @@
from .utils_functions import *
def prepare_data(batch,
scale_phi_ini:float = 1000.0,
scale_delta_ini:float = 12.0,
scale_target:float = 600.0,
scale_phi_ini:float = 1000.0,
scale_delta_ini:float = 12.0,
scale_target:float = 50.0,
lin_threshold_target:float = 2.,
):
target_transform = CustomSymlogTransform(threshold=lin_threshold_target)
# delta_ini = batch['input'][:, [0], :, :, :]
phi_ini = batch['input'][:, [1], :, :, :]
D1 = batch['style'][:, [0], None, None, None]
@ -17,6 +22,7 @@ def prepare_data(batch,
_target = (gravpot/D1 - phi_ini)/D1
_target /= scale_target
_target = target_transform(_target)
_style = batch['style']