fix deprecated FftType in jaxpm.kernels

This commit is contained in:
Wassim Kabalan 2024-12-08 22:54:52 +01:00
parent 5d34d3c3a8
commit adaf7d236d

View file

@ -1,6 +1,6 @@
import jax.numpy as jnp
import numpy as np
from jax.lib.xla_client import FftType
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs