import jax.numpy as jnp def MSE(x, y): return jnp.mean((x - y)**2) def MSE_3D(x, y): return ((x - y)**2).mean(axis=0) def MSRE(x, y): return jnp.mean(((x - y) / y)**2)