import jax.numpy as jnp
def MSE(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    The squared element-wise difference is averaged over *all* elements
    (``.mean()`` with no axis), so the result is a scalar.

    Args:
        x: Predictions; any array supporting ``-`` and ``.mean()``.
        y: Targets, broadcast-compatible with ``x``.

    Returns:
        Scalar mean of ``(x - y) ** 2``.
    """
    residual = x - y
    squared = residual * residual
    return squared.mean()
def MSE_3D(x, y):
    """Return the squared error averaged over the leading axis only.

    Unlike ``MSE``, this function does not reduce to a scalar. The result
    keeps every axis except axis 0.
    NOTE(review): the name suggests 3-D inputs (e.g. samples along axis 0),
    but the reduction works for any array of rank >= 1. Confirm the intended
    layout with callers.

    Args:
        x: Predictions; any array supporting ``-`` and ``.mean(axis=...)``.
        y: Targets, broadcast-compatible with ``x``.

    Returns:
        Array of shape ``(x - y).shape[1:]`` holding per-position mean
        squared errors.
    """
    err = x - y
    return (err * err).mean(axis=0)
def MSRE(x, y, *, eps=0.0):
    """Return the mean squared relative error of ``x`` with respect to ``y``.

    Computes ``mean(((x - y) / y) ** 2)`` over all elements, so the result
    is a scalar.

    Because the ratio is squared, dividing by ``|y|`` gives the same value
    as dividing by ``y`` for real inputs. This lets the denominator be
    floored at ``eps`` without sign handling. With the default ``eps=0.0``
    the result matches the unguarded formula exactly, so a zero target
    still yields ``inf`` or ``nan``. Pass a small positive ``eps`` to keep
    the loss finite when ``y`` can be zero.
    NOTE(review): assumes real-valued ``y``; for complex ``y``, ``|y|``
    would change the result.

    Args:
        x: Predictions; any array supporting ``-``, ``/`` and ``.mean()``.
        y: Targets (denominator), broadcast-compatible with ``x``.
        eps: Keyword-only lower bound applied to ``|y|`` before dividing.
            Defaults to 0.0, which means no guard.

    Returns:
        Scalar mean squared relative error.
    """
    # Squaring discards the sign, so flooring |y| is equivalent to guarding y.
    denom = jnp.maximum(jnp.abs(y), eps)
    return (((x - y) / denom) ** 2).mean()