diff --git a/jaxpm/pm.py b/jaxpm/pm.py
index 3155467..321bf0c 100644
--- a/jaxpm/pm.py
+++ b/jaxpm/pm.py
@@ -43,10 +43,8 @@ def pm_forces(positions,
     # Computes gravitational forces
     forces = jnp.stack([
         cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
-                    halo_size=halo_size,
-                    sharding=sharding) for i in range(3)
-    ],
-                       axis=-1)
+        halo_size=halo_size,
+        sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
 
     return forces
 
@@ -58,8 +56,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
     """
     gpu_mesh = sharding.mesh if sharding is not None else None
     spec = sharding.spec if sharding is not None else P()
-    local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding),
-                        3)
+    local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
     displacement = autoshmap(
       partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
       gpu_mesh=gpu_mesh,
@@ -88,7 +85,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
             # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
             # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
             nabla_i_nabla_i = gradient_kernel(kvec, i)**2
-            shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
+            shear_ii = fft3d(nabla_i_nabla_i * pot_k)
             delta2 += shear_ii * shear_acc
             shear_acc += shear_ii
 
@@ -98,7 +95,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
                 # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
                 nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
                     kvec, j)
-                delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
+                delta2 -= fft3d(nabla_i_nabla_j * pot_k)**2
 
         delta_k2 = fft3d(delta2)
         init_force2 = pm_forces(displacement,
@@ -191,16 +188,16 @@ def pgd_correction(pos, mesh_shape, params):
       pos: particle positions [npart, 3]
       params: [alpha, kl, ks] pgd parameters
     """
-    kvec = fftk(mesh_shape)
     delta = cic_paint(jnp.zeros(mesh_shape), pos)
+    delta_k = fft3d(delta)
+    kvec = fftk(delta_k)
     alpha, kl, ks = params
-    delta_k = jnp.fft.rfftn(delta)
     PGD_range = PGD_kernel(kvec, kl, ks)
 
     pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range
 
     forces_pgd = jnp.stack([
-        cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
+        cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
         for i in range(3)
     ],
                            axis=-1)
@@ -217,11 +214,9 @@ def make_neural_ode_fn(model, mesh_shape):
         state is a tuple (position, velocities)
         """
         pos, vel = state
-        kvec = fftk(mesh_shape)
-
         delta = cic_paint(jnp.zeros(mesh_shape), pos)
-
-        delta_k = jnp.fft.rfftn(delta)
+        delta_k = fft3d(delta)
+        kvec = fftk(delta_k)
 
         # Computes gravitational potential
         pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
@@ -233,7 +228,7 @@ def make_neural_ode_fn(model, mesh_shape):
 
         # Computes gravitational forces
         forces = jnp.stack([
-            cic_read(jnp.fft.irfftn(-gradient_kernel(kvec, i) * pot_k), pos)
+            cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
             for i in range(3)
         ],
                            axis=-1)