diff --git a/jaxpm/painting.py b/jaxpm/painting.py index f3c50df..78d63ef 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -134,7 +134,6 @@ def cic_paint_2d(mesh, positions, weight): positions: [npart, 2] weight: [npart] """ - positions = positions.reshape([-1, 2]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -143,7 +142,7 @@ def cic_paint_2d(mesh, positions, weight): kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] if weight is not None: - kernel = kernel * weight.reshape(*positions.shape[:-1]) + kernel = kernel * weight[..., jnp.newaxis] neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 4, 2]).astype('int32'),