symlog transform of the target
This commit is contained in:
parent
58f9b27e6e
commit
1abb0f9fb6
2 changed files with 50 additions and 3 deletions
|
@ -1,9 +1,14 @@
|
||||||
|
from .utils_functions import *
|
||||||
|
|
||||||
def prepare_data(batch,
|
def prepare_data(batch,
|
||||||
scale_phi_ini:float = 1000.0,
|
scale_phi_ini:float = 1000.0,
|
||||||
scale_delta_ini:float = 12.0,
|
scale_delta_ini:float = 12.0,
|
||||||
scale_target:float = 600.0,
|
scale_target:float = 50.0,
|
||||||
|
lin_threshold_target:float = 2.,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
target_transform = CustomSymlogTransform(threshold=lin_threshold_target)
|
||||||
|
|
||||||
# delta_ini = batch['input'][:, [0], :, :, :]
|
# delta_ini = batch['input'][:, [0], :, :, :]
|
||||||
phi_ini = batch['input'][:, [1], :, :, :]
|
phi_ini = batch['input'][:, [1], :, :, :]
|
||||||
D1 = batch['style'][:, [0], None, None, None]
|
D1 = batch['style'][:, [0], None, None, None]
|
||||||
|
@ -17,6 +22,7 @@ def prepare_data(batch,
|
||||||
|
|
||||||
_target = (gravpot/D1 - phi_ini)/D1
|
_target = (gravpot/D1 - phi_ini)/D1
|
||||||
_target /= scale_target
|
_target /= scale_target
|
||||||
|
_target = target_transform(_target)
|
||||||
|
|
||||||
_style = batch['style']
|
_style = batch['style']
|
||||||
|
|
||||||
|
|
41
sCOCA_ML/prepare_data/utils_functions.py
Normal file
41
sCOCA_ML/prepare_data/utils_functions.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Transformation:
|
||||||
|
def __init__(self, transform=None, inverse_transform=None):
|
||||||
|
self.transform = transform
|
||||||
|
self.inverse_transform = inverse_transform
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
if self.transform:
|
||||||
|
return self.transform(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
def inverse(self, sample):
|
||||||
|
if self.inverse_transform:
|
||||||
|
return self.inverse_transform(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def custom_symlog(x, threshold=1.):
|
||||||
|
"""
|
||||||
|
Custom symmetric logarithm function.
|
||||||
|
Returns log(1 + |x| / threshold) * sign(x) * threshold.
|
||||||
|
"""
|
||||||
|
return torch.sign(x) * torch.log1p(torch.abs(x) / threshold) * threshold
|
||||||
|
|
||||||
|
def custom_inverse_symlog(x, threshold=1.):
|
||||||
|
"""
|
||||||
|
Custom inverse symmetric logarithm function.
|
||||||
|
Returns sign(x) * (exp(|x| / threshold) - 1) * threshold.
|
||||||
|
"""
|
||||||
|
return torch.sign(x) * (torch.exp(torch.abs(x) / threshold) - 1) * threshold
|
||||||
|
|
||||||
|
|
||||||
|
def CustomSymlogTransform(threshold=1.):
|
||||||
|
"""
|
||||||
|
Custom transformation object for symmetric logarithm.
|
||||||
|
"""
|
||||||
|
return Transformation(
|
||||||
|
transform=lambda x: custom_symlog(x, threshold),
|
||||||
|
inverse_transform=lambda x: custom_inverse_symlog(x, threshold)
|
||||||
|
)
|
Loading…
Add table
Add a link
Reference in a new issue