mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
jit cic_paint
This commit is contained in:
parent
d62c38f457
commit
b4fdb74660
1 changed files with 1 additions and 1 deletions
|
@ -41,7 +41,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
|
||||||
return mesh
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
#@partial(jax.jit, static_argnums=(2, 3, 4))
|
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||||
|
|
||||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||||
|
|
Loading…
Add table
Reference in a new issue