diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index ab85856..4fdb764 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -30,7 +30,7 @@ def autoshmap( def fft3d(x): - return jaxdecomp.pfft3d(x.astype(jnp.complex64)) + return jaxdecomp.pfft3d(x) def ifft3d(x):