fix 2D painting when input is (X , Y , 2) shape

This commit is contained in:
Wassim Kabalan 2025-05-09 22:07:53 +02:00
parent 20ace41d32
commit d4049e5db4

View file

@ -134,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight):
positions: [npart, 2]
weight: [npart]
"""
positions = positions.reshape([-1, 2])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -142,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[..., jnp.newaxis]
kernel = kernel * weight.reshape(*positions.shape[:-1])
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 4, 2]).astype('int32'),