diff --git a/jaxpm/painting.py b/jaxpm/painting.py index fd52106..c26f895 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -41,7 +41,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None): return mesh -#@partial(jax.jit, static_argnums=(2, 3, 4)) +@partial(jax.jit, static_argnums=(2, 3, 4)) def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3))