From adaf7d236d8e236fa33a8fd505280318a15c758e Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Sun, 8 Dec 2024 22:54:52 +0100 Subject: [PATCH] fix deprecated FftType in jaxpm.kernels --- jaxpm/kernels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 7672093..912fe2f 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,6 +1,6 @@ import jax.numpy as jnp import numpy as np -from jax.lib.xla_client import FftType +from jax.lax import FftType from jax.sharding import PartitionSpec as P from jaxdecomp import fftfreq3d, get_output_specs