jit cic_paint

This commit is contained in:
Wassim KABALAN 2024-10-27 03:48:50 +01:00
parent d62c38f457
commit b4fdb74660

View file

@ -41,7 +41,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
return mesh
#@partial(jax.jit, static_argnums=(2, 3, 4))
@partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))