diff --git a/jaxpm/painting.py b/jaxpm/painting.py index e4363ff..76bb9b6 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -46,8 +46,8 @@ def cic_paint_impl(grid_mesh, positions, weight=None): return mesh -@partial(jax.jit, static_argnums=(2, 3, 4)) -def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): +@partial(jax.jit, static_argnums=(3, 4)) +def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3)) @@ -112,6 +112,7 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None): halo_periods=(True, True)) gpu_mesh = sharding.mesh if sharding is not None else None spec = sharding.spec if sharding is not None else P() + displacement = autoshmap(cic_read_impl, gpu_mesh=gpu_mesh, in_specs=(spec, spec),