Add exp and log functions from torch
This commit is contained in:
parent
698b2a8df7
commit
ea61bb9b65
26
map2map/data/norms/torch.py
Normal file
26
map2map/data/norms/torch.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def exp(x, undo=False):
|
||||||
|
if not undo:
|
||||||
|
torch.exp(x, out=x)
|
||||||
|
else:
|
||||||
|
torch.log(x, out=x)
|
||||||
|
|
||||||
|
def log(x, eps=1e-8, undo=False):
|
||||||
|
if not undo:
|
||||||
|
torch.log(x + eps, out=x)
|
||||||
|
else:
|
||||||
|
torch.exp(x, out=x)
|
||||||
|
|
||||||
|
def expm1(x, undo=False):
|
||||||
|
if not undo:
|
||||||
|
torch.expm1(x, out=x)
|
||||||
|
else:
|
||||||
|
torch.log1p(x, out=x)
|
||||||
|
|
||||||
|
def log1p(x, eps=1e-7, undo=False):
|
||||||
|
if not undo:
|
||||||
|
torch.log1p(x + eps, out=x)
|
||||||
|
else:
|
||||||
|
torch.expm1(x, out=x)
|
Loading…
Reference in New Issue
Block a user