diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 27a1900..95bead9 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -52,6 +52,31 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) +def cic_paint_2d(mesh, positions): + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,4]), + dnums) + return mesh + def compensate_cic(field): """ Compensate for CiC painting