mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
jit cic_paint
This commit is contained in:
parent
d62c38f457
commit
b4fdb74660
1 changed files with 1 additions and 1 deletions
|
@ -41,7 +41,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
return mesh
|
||||
|
||||
|
||||
#@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
|
Loading…
Add table
Reference in a new issue