From e9529d35f838e708c3e14c9cce5e176d66d07e28 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 30 Oct 2024 01:57:56 +0100 Subject: [PATCH] Weights can be traced --- jaxpm/painting.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index e4363ff..76bb9b6 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -46,8 +46,8 @@ def cic_paint_impl(grid_mesh, positions, weight=None): return mesh -@partial(jax.jit, static_argnums=(2, 3, 4)) -def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): +@partial(jax.jit, static_argnums=(3, 4)) +def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3)) @@ -112,6 +112,7 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None): halo_periods=(True, True)) gpu_mesh = sharding.mesh if sharding is not None else None spec = sharding.spec if sharding is not None else P() + displacement = autoshmap(cic_read_impl, gpu_mesh=gpu_mesh, in_specs=(spec, spec),