From d4049e5db46d811b6684021ee89642b8c1eab944 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Fri, 9 May 2025 22:07:53 +0200 Subject: [PATCH] fix 2D painting when input is (X , Y , 2) shape --- jaxpm/painting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 78d63ef..f3c50df 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -134,6 +134,7 @@ def cic_paint_2d(mesh, positions, weight): positions: [npart, 2] weight: [npart] """ + positions = positions.reshape([-1, 2]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -142,7 +143,7 @@ def cic_paint_2d(mesh, positions, weight): kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] if weight is not None: - kernel = kernel * weight[..., jnp.newaxis] + kernel = kernel * weight.reshape(*positions.shape[:-1]) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 4, 2]).astype('int32'),