mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-13 11:31:11 +00:00
fix 2D painting when input is (X , Y , 2) shape
This commit is contained in:
parent
20ace41d32
commit
d4049e5db4
1 changed files with 2 additions and 1 deletions
|
@ -134,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = positions.reshape([-1, 2])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
|
||||
|
@ -142,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[..., jnp.newaxis]
|
||||
kernel = kernel * weight.reshape(*positions.shape[:-1])
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue