import functools
from typing import Any, Callable, Optional

import torch


class Transformation:
    """Pair a forward transform with its inverse.

    Either callable may be omitted; an omitted callable acts as the identity
    and returns the sample unchanged.
    """

    def __init__(
        self,
        transform: Optional[Callable[[Any], Any]] = None,
        inverse_transform: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Args:
            transform: Callable applied by ``__call__``; ``None`` means identity.
            inverse_transform: Callable applied by ``inverse``; ``None`` means
                identity.
        """
        self.transform = transform
        self.inverse_transform = inverse_transform

    def __call__(self, sample):
        """Apply the forward transform, or return ``sample`` unchanged if unset."""
        # Compare against None explicitly: a callable object that happens to be
        # falsy (e.g. one whose ``__len__`` returns 0) must still be applied.
        if self.transform is not None:
            return self.transform(sample)
        return sample

    def inverse(self, sample):
        """Apply the inverse transform, or return ``sample`` unchanged if unset."""
        if self.inverse_transform is not None:
            return self.inverse_transform(sample)
        return sample


def custom_symlog(x, threshold=1.):
    """Custom symmetric logarithm.

    Computes ``sign(x) * log(1 + |x| / threshold) * threshold``. For
    ``|x| << threshold`` this is approximately ``x`` (slope 1 at the origin);
    for larger magnitudes it grows logarithmically. Zero maps to zero.

    Args:
        x: Input tensor.
        threshold: Scale of the linear region. Expected to be positive; it is
            used as a divisor, so zero yields non-finite results.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # log1p keeps full precision when |x| / threshold is tiny.
    return torch.sign(x) * torch.log1p(torch.abs(x) / threshold) * threshold


def custom_inverse_symlog(x, threshold=1.):
    """Inverse of :func:`custom_symlog`.

    Computes ``sign(x) * (exp(|x| / threshold) - 1) * threshold``.

    Args:
        x: Tensor produced by ``custom_symlog`` with the same ``threshold``.
        threshold: Same positive scale used for the forward transform.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # expm1 instead of exp(...) - 1: the subtraction cancels catastrophically
    # for small |x| (e.g. ~19% error at 1e-7 in float32). expm1 is the exact
    # counterpart of the log1p used in the forward transform.
    return torch.sign(x) * torch.expm1(torch.abs(x) / threshold) * threshold


def CustomSymlogTransform(threshold=1.):
    """Build a :class:`Transformation` for the custom symmetric logarithm.

    Args:
        threshold: Scale passed to both ``custom_symlog`` and
            ``custom_inverse_symlog``.

    Returns:
        A ``Transformation`` whose call applies ``custom_symlog`` and whose
        ``inverse`` applies ``custom_inverse_symlog``.
    """
    # functools.partial over module-level functions (rather than lambdas) keeps
    # the returned object picklable, e.g. for multiprocessing data loaders.
    return Transformation(
        transform=functools.partial(custom_symlog, threshold=threshold),
        inverse_transform=functools.partial(custom_inverse_symlog, threshold=threshold),
    )