mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
Weights can be traced
This commit is contained in:
parent
b09580d59e
commit
e9529d35f8
1 changed files with 3 additions and 2 deletions
|
@ -46,8 +46,8 @@ def cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
return mesh
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
|
@ -112,6 +112,7 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
|||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
|
|
Loading…
Add table
Reference in a new issue