ML_GravPotBCs/sCOCA_ML/prepare_data/utils_functions.py

41 lines
1.2 KiB
Python

import torch
class Transformation:
    """Pair a forward transform with its inverse.

    Either callable may be omitted. The corresponding direction then acts as
    the identity and returns its input unchanged.

    Args:
        transform: Callable applied by ``__call__``; ``None`` means identity.
        inverse_transform: Callable applied by ``inverse``; ``None`` means
            identity.
    """

    def __init__(self, transform=None, inverse_transform=None):
        self.transform = transform
        self.inverse_transform = inverse_transform

    def __call__(self, sample):
        """Apply the forward transform to *sample* (identity if unset)."""
        # Compare against None explicitly: a callable object that defines
        # ``__len__`` or ``__bool__`` may be falsy, yet it must still be applied.
        if self.transform is not None:
            return self.transform(sample)
        return sample

    def inverse(self, sample):
        """Apply the inverse transform to *sample* (identity if unset)."""
        if self.inverse_transform is not None:
            return self.inverse_transform(sample)
        return sample

    def __repr__(self):
        return (
            f"{type(self).__name__}(transform={self.transform!r}, "
            f"inverse_transform={self.inverse_transform!r})"
        )
def custom_symlog(x, threshold=1.):
    """
    Symmetric logarithm scaled by ``threshold``.

    Computes ``sign(x) * log(1 + |x| / threshold) * threshold`` elementwise.
    For a positive ``threshold`` this is close to the identity when
    ``|x| << threshold`` and grows logarithmically when ``|x| >> threshold``.

    Args:
        x: Input tensor (``torch.sign``/``torch.abs`` require a Tensor).
        threshold: Scale of the near-linear region around zero.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # log1p keeps precision when |x| / threshold is tiny.
    magnitude = torch.log1p(torch.abs(x) / threshold)
    return threshold * magnitude * torch.sign(x)
def custom_inverse_symlog(x, threshold=1.):
    """
    Inverse of ``custom_symlog``.

    Computes ``sign(x) * (exp(|x| / threshold) - 1) * threshold`` elementwise.
    ``torch.expm1`` evaluates ``exp(a) - 1`` without the cancellation error
    that a literal ``exp(a) - 1`` suffers for small ``a``. This mirrors the
    ``log1p`` used in the forward transform, so small values round-trip
    accurately.

    Args:
        x: Input tensor (``torch.sign``/``torch.abs`` require a Tensor).
        threshold: Same scale that was passed to ``custom_symlog``.

    Returns:
        Tensor of the same shape as ``x``.
    """
    return torch.sign(x) * torch.expm1(torch.abs(x) / threshold) * threshold
def CustomSymlogTransform(threshold=1.):
    """
    Build a ``Transformation`` wrapping the symmetric-log pair.

    Args:
        threshold: Scale passed to both ``custom_symlog`` and
            ``custom_inverse_symlog``; it is captured when the
            transformation is built.

    Returns:
        Transformation whose forward call applies ``custom_symlog`` and whose
        ``inverse`` applies ``custom_inverse_symlog``.
    """
    def forward(sample):
        return custom_symlog(sample, threshold)

    def backward(sample):
        return custom_inverse_symlog(sample, threshold)

    return Transformation(transform=forward, inverse_transform=backward)