diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index 37a86a9..838fe38 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -11,17 +11,11 @@ from jaxpm.kernels import cic_compensation, fftk
 from jaxpm.painting_utils import gather, scatter
 
 
-def cic_paint_impl(grid_mesh, displacement, weight=None):
+def cic_paint_impl(grid_mesh, positions, weight=None):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     displacement field: [nx, ny, nz, 3]
     """
-    part_shape = displacement.shape
-    positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
-                                       jnp.arange(part_shape[1]),
-                                       jnp.arange(part_shape[2]),
-                                       indexing='ij'),
-                          axis=-1) + displacement
     positions = positions.reshape([-1, 3])
     positions = jnp.expand_dims(positions, 1)
     floor = jnp.floor(positions)
@@ -47,7 +41,7 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
     return mesh
 
 
-@partial(jax.jit, static_argnums=(2, 3, 4))
+#@partial(jax.jit, static_argnums=(2, 3, 4))
 def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
 
     positions = positions.reshape((*grid_mesh.shape, 3))
@@ -66,43 +60,46 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
                               halo_periods=(True, True))
     grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
 
-    print(f"shape of grid_mesh: {grid_mesh.shape}")
     return grid_mesh
 
 
-def cic_read_impl(mesh, displacement):
+def cic_read_impl(grid_mesh, positions):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
-    displacement: [nx,ny,nz, 3]
+    positions: [nx,ny,nz, 3]
     """
-    # Compute the position of the particles on a regular grid
-    part_shape = displacement.shape
-    positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]),
-                                       jnp.arange(part_shape[1]),
-                                       jnp.arange(part_shape[2]),
-                                       indexing='ij'),
-                          axis=-1) + displacement
+    # Save original shape for reshaping output later
+    original_shape = positions.shape
+    # Reshape positions to a flat list of 3D coordinates
     positions = positions.reshape([-1, 3])
+    # Expand dimensions to calculate neighbor coordinates
     positions = jnp.expand_dims(positions, 1)
+    # Floor the positions to get the base grid cell for each particle
     floor = jnp.floor(positions)
+    # Define connections to calculate all neighbor coordinates
     connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                              [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
-
+    # Calculate the 8 neighboring coordinates
     neighboor_coords = floor + connection
+    # Calculate kernel weights based on distance from each neighboring coordinate
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
-
+    # Modulo operation to wrap around edges if necessary
     neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
-                               jnp.array(mesh.shape))
-
-    return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
-                 neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape(
-                     displacement.shape[:-1])
+                               jnp.array(grid_mesh.shape))
+    # Ensure grid_mesh shape is as expected
+    # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
+    return (grid_mesh[neighboor_coords[..., 0],
+                      neighboor_coords[..., 1],
+                      neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
 
 
 @partial(jax.jit, static_argnums=(2, 3))
 def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
 
+    original_shape = positions.shape
+    positions = positions.reshape((*grid_mesh.shape, 3))
+
     halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
     grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
     grid_mesh = halo_exchange(grid_mesh,
@@ -114,9 +111,8 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
                              gpu_mesh=gpu_mesh,
                              in_specs=(spec, spec),
                              out_specs=spec)(grid_mesh, positions)
-    print(f"shape of displacement: {displacement.shape}")
 
-    return displacement
+    return displacement.reshape(original_shape[:-1])
 
 
 def cic_paint_2d(mesh, positions, weight):