# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
# (synced 2025-02-23 01:57:10 +00:00; 13 lines, 184 B, Python).
import jax.numpy as jnp


def MSE(x, y):
    """Return the mean squared error between ``x`` and ``y``.

    The squared element-wise difference is averaged over *all* elements,
    so the result is a scalar regardless of the input rank.

    Args:
        x: Prediction array.
        y: Reference array (must broadcast against ``x``).

    Returns:
        Scalar mean of ``(x - y) ** 2``.
    """
    residual = x - y
    squared = residual ** 2
    return squared.mean()


def MSE_3D(x, y):
    """Return the squared error between ``x`` and ``y`` averaged over axis 0.

    Unlike ``MSE``, only the leading axis is reduced, so the remaining axes
    are kept. For example, inputs of shape ``(N, 3)`` give a result of shape
    ``(3,)``. (Presumably the "3D" in the name refers to per-coordinate
    error for 3-vectors; the code does not enforce any shape. TODO confirm.)

    Args:
        x: Prediction array.
        y: Reference array (must broadcast against ``x``).

    Returns:
        Array equal to ``((x - y) ** 2).mean(axis=0)``.
    """
    squared_error = (x - y) ** 2
    # Reduce only the leading axis so the trailing dimensions survive.
    return squared_error.mean(axis=0)


def MSRE(x, y):
    """Return the mean squared relative error of ``x`` with respect to ``y``.

    Each difference is normalised by the reference value ``y`` before it is
    squared. The result is then averaged over all elements.

    Note:
        Zero entries in ``y`` are not guarded. With floating-point inputs
        they yield ``inf``/``nan`` in the result.

    Args:
        x: Prediction array.
        y: Reference array used as the normaliser (must broadcast against ``x``).

    Returns:
        Scalar mean of ``((x - y) / y) ** 2``.
    """
    relative = (x - y) / y
    return (relative ** 2).mean()