Weights can be traced

This commit is contained in:
Wassim KABALAN 2024-10-30 01:57:56 +01:00
parent b09580d59e
commit e9529d35f8

View file

@ -46,8 +46,8 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
return mesh
@partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
@ -112,6 +112,7 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacement = autoshmap(cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),