diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 8bf6e2e..1a19b45 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -1,6 +1,6 @@ import numpy as np import jax.numpy as jnp -from scipy.stats import norm +from jax.scipy.stats import norm __all__ = ['power_spectrum']