From b4fdb7466045bd99930594c9d418e01f4c3cd3a0 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sun, 27 Oct 2024 03:48:50 +0100 Subject: [PATCH] jit cic_paint --- jaxpm/painting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index fd52106..c26f895 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -41,7 +41,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None): return mesh -#@partial(jax.jit, static_argnums=(2, 3, 4)) +@partial(jax.jit, static_argnums=(2, 3, 4)) def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3))