mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
Allow applying weights with relative cic paint
This commit is contained in:
parent
b3a264ad53
commit
b09580d59e
2 changed files with 28 additions and 12 deletions
|
@ -104,8 +104,7 @@ def _scatter_chunk(carry, chunk):
|
|||
spatial_shape)
|
||||
# scatter
|
||||
ind = tuple(ind[..., i] for i in range(spatial_ndim))
|
||||
mesh = mesh.at[ind].add(val * frac)
|
||||
|
||||
mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
||||
|
@ -117,11 +116,9 @@ def scatter(pmid,
|
|||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jnp.asarray(val)
|
||||
mesh = jnp.asarray(mesh)
|
||||
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
if remainder is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue