mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 04:10:54 +00:00
13 lines
186 B
Python
13 lines
186 B
Python
import jax.numpy as jnp
|
|
|
|
|
|
def MSE(x, y):
    """Mean squared error between ``x`` and ``y``.

    The squared element-wise difference is averaged over every element,
    so the result is a scalar.

    Parameters
    ----------
    x, y : array_like
        Arrays of matching (or broadcast-compatible) shape.

    Returns
    -------
    jax.Array
        Scalar mean of ``(x - y) ** 2``.
    """
    residual = x - y
    return jnp.mean(residual**2)
|
|
|
|
|
|
def MSE_3D(x, y):
    """Mean squared error reduced along the leading axis only.

    The squared element-wise difference is averaged over axis 0 and all
    remaining axes are kept. For example, inputs of shape ``(n, 3)`` give a
    result of shape ``(3,)``.

    NOTE(review): presumably ``x`` and ``y`` are ``(n_particles, 3)``
    position arrays, which would make the result a per-dimension error.
    TODO confirm against callers.

    Parameters
    ----------
    x, y : array_like
        Arrays of matching (or broadcast-compatible) shape, with at least
        one dimension.

    Returns
    -------
    array
        Mean of ``(x - y) ** 2`` over axis 0. The reduction uses the
        difference array's own ``.mean`` method, so the array type follows
        the inputs.
    """
    squared_residual = (x - y) ** 2
    return squared_residual.mean(axis=0)
|
|
|
|
|
|
def MSRE(x, y, *, eps=0.0):
    """Mean squared relative error of ``x`` with respect to the reference ``y``.

    Computes ``mean(((x - y) / (|y| + eps)) ** 2)`` over all elements and
    returns a scalar.

    Parameters
    ----------
    x : array_like
        Values being compared.
    y : array_like
        Reference values; each error is normalized by the magnitude of the
        corresponding reference.
    eps : float, optional
        Non-negative offset added to ``|y|`` in the denominator. With the
        default ``0.0`` the result is identical to
        ``mean(((x - y) / y) ** 2)``. In that case any zero in ``y`` yields
        ``inf``, or ``nan`` where ``x`` is also zero. Pass a small positive
        value to keep the loss finite.

    Returns
    -------
    jax.Array
        Scalar mean squared relative error.
    """
    # The ratio is squared, so dividing by |y| instead of y does not change
    # the result. It also stops eps from partially cancelling against
    # negative references.
    denom = jnp.abs(y) + eps
    return jnp.mean(((x - y) / denom) ** 2)
|