add tests

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:57 +01:00
parent 36ef18e3d0
commit f70583b5fd
4 changed files with 483 additions and 0 deletions

10
tests/helpers.py Normal file
View file

@ -0,0 +1,10 @@
import jax.numpy as jnp
def MSE(x , y):
return jnp.mean((x - y)**2)
def MSE_3D(x , y):
return ((x - y)**2).mean(axis=0)
def MSRE(x , y):
return jnp.mean(((x - y)/ y)**2)