From 1abb0f9fb6dec494c917a70c5e2b31074575be0d Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Tue, 24 Jun 2025 09:26:37 +0200 Subject: [PATCH] symlog transform of the target --- sCOCA_ML/prepare_data/prepare_gravpot_data.py | 12 ++++-- sCOCA_ML/prepare_data/utils_functions.py | 41 +++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 sCOCA_ML/prepare_data/utils_functions.py diff --git a/sCOCA_ML/prepare_data/prepare_gravpot_data.py b/sCOCA_ML/prepare_data/prepare_gravpot_data.py index 587bec0..3c26ac0 100644 --- a/sCOCA_ML/prepare_data/prepare_gravpot_data.py +++ b/sCOCA_ML/prepare_data/prepare_gravpot_data.py @@ -1,9 +1,14 @@ +from .utils_functions import * + def prepare_data(batch, - scale_phi_ini:float = 1000.0, - scale_delta_ini:float = 12.0, - scale_target:float = 600.0, + scale_phi_ini:float = 1000.0, + scale_delta_ini:float = 12.0, + scale_target:float = 50.0, + lin_threshold_target:float = 2., ): + target_transform = CustomSymlogTransform(threshold=lin_threshold_target) + # delta_ini = batch['input'][:, [0], :, :, :] phi_ini = batch['input'][:, [1], :, :, :] D1 = batch['style'][:, [0], None, None, None] @@ -17,6 +22,7 @@ def prepare_data(batch, _target = (gravpot/D1 - phi_ini)/D1 _target /= scale_target + _target = target_transform(_target) _style = batch['style'] diff --git a/sCOCA_ML/prepare_data/utils_functions.py b/sCOCA_ML/prepare_data/utils_functions.py new file mode 100644 index 0000000..e2f2598 --- /dev/null +++ b/sCOCA_ML/prepare_data/utils_functions.py @@ -0,0 +1,41 @@ +import torch + +class Transformation: + def __init__(self, transform=None, inverse_transform=None): + self.transform = transform + self.inverse_transform = inverse_transform + + def __call__(self, sample): + if self.transform: + return self.transform(sample) + return sample + + def inverse(self, sample): + if self.inverse_transform: + return self.inverse_transform(sample) + return sample + + +def custom_symlog(x, threshold=1.): + """ + Custom symmetric logarithm function. + Returns log(1 + |x| / threshold) * sign(x) * threshold. + """ + return torch.sign(x) * torch.log1p(torch.abs(x) / threshold) * threshold + +def custom_inverse_symlog(x, threshold=1.): + """ + Custom inverse symmetric logarithm function. + Returns sign(x) * (exp(|x| / threshold) - 1) * threshold. + """ + return torch.sign(x) * (torch.exp(torch.abs(x) / threshold) - 1) * threshold + + +def CustomSymlogTransform(threshold=1.): + """ + Custom transformation object for symmetric logarithm. + """ + return Transformation( + transform=lambda x: custom_symlog(x, threshold), + inverse_transform=lambda x: custom_inverse_symlog(x, threshold) + )