41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import torch
|
|
|
|
class Transformation:
    """
    Pair of callables: a forward transform and its inverse.

    Either callable may be omitted (left as ``None``). A missing callable acts
    as the identity and returns the sample unchanged.

    Args:
        transform: Callable applied by ``__call__``, or ``None`` for identity.
        inverse_transform: Callable applied by ``inverse``, or ``None`` for
            identity.
    """

    def __init__(self, transform=None, inverse_transform=None):
        self.transform = transform  # forward callable, or None for identity
        self.inverse_transform = inverse_transform  # inverse callable, or None

    def __call__(self, sample):
        """Return ``transform(sample)``, or ``sample`` itself if no transform is set."""
        # Compare against None rather than testing truthiness: a callable object
        # that defines __len__ or __bool__ may be falsy yet must still be applied.
        if self.transform is not None:
            return self.transform(sample)
        return sample

    def inverse(self, sample):
        """Return ``inverse_transform(sample)``, or ``sample`` itself if none is set."""
        if self.inverse_transform is not None:
            return self.inverse_transform(sample)
        return sample
|
|
|
|
|
|
def custom_symlog(x, threshold=1.):
    """
    Custom symmetric logarithm function.

    Computes ``sign(x) * log(1 + |x| / threshold) * threshold`` element-wise.
    The result is odd in ``x``, passes through the origin with slope 1, and
    grows logarithmically once ``|x|`` is large compared to ``threshold``.

    Args:
        x: Input tensor.
        threshold: Positive scale at which the curve turns from roughly linear
            to logarithmic. May be a Python number or a tensor that
            broadcasts against ``x``.

    Returns:
        Tensor holding the transformed values (broadcast of ``x`` and
        ``threshold``).

    Raises:
        ValueError: If ``threshold`` is a Python number that is not positive.
            Such values would otherwise silently produce NaN or meaningless
            results.
    """
    # Only scalar thresholds are checked here; tensor thresholds are passed
    # through as before so element-wise / broadcast thresholds keep working.
    if isinstance(threshold, (int, float)) and threshold <= 0:
        raise ValueError(f"threshold must be positive, got {threshold!r}")
    # log1p keeps full precision when |x| / threshold is tiny, where
    # log(1 + t) would round to zero.
    return torch.sign(x) * torch.log1p(torch.abs(x) / threshold) * threshold
|
|
|
|
def custom_inverse_symlog(x, threshold=1.):
    """
    Custom inverse symmetric logarithm function.

    Computes ``sign(x) * (exp(|x| / threshold) - 1) * threshold`` element-wise.
    This is the inverse of ``custom_symlog`` called with the same
    ``threshold``.

    Args:
        x: Input tensor (values in symlog space).
        threshold: Positive scale; must match the one used in the forward
            transform. May be a Python number or a tensor that broadcasts
            against ``x``.

    Returns:
        Tensor holding the values mapped back to the original space.

    Raises:
        ValueError: If ``threshold`` is a Python number that is not positive.
    """
    if isinstance(threshold, (int, float)) and threshold <= 0:
        raise ValueError(f"threshold must be positive, got {threshold!r}")
    # expm1(t) instead of exp(t) - 1: the subtraction cancels catastrophically
    # for small t (in float32 it yields exactly 0 for |t| below ~1e-7), which
    # would break round-trips with the log1p-based forward transform.
    return torch.sign(x) * torch.expm1(torch.abs(x) / threshold) * threshold
|
|
|
|
|
|
def CustomSymlogTransform(threshold=1.):
    """
    Build a ``Transformation`` for the symmetric logarithm with a fixed threshold.

    The forward transform is ``custom_symlog(x, threshold)``. The inverse is
    ``custom_inverse_symlog(x, threshold)``.

    Args:
        threshold: Scale passed to both the forward and inverse functions.

    Returns:
        Transformation: Calling it applies the symlog; its ``inverse`` method
        undoes it.
    """
    from functools import partial

    # functools.partial over module-level functions (instead of lambdas) keeps
    # the returned object picklable, e.g. for torch.save or for sending it to
    # worker processes; lambdas cannot be pickled.
    return Transformation(
        transform=partial(custom_symlog, threshold=threshold),
        inverse_transform=partial(custom_inverse_symlog, threshold=threshold),
    )
|