From e188d5efb6e094fecf864fd888444f76adb4ff32 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 11:19:56 +0200 Subject: [PATCH] adding function for doing 2d paintinng --- jaxpm/painting.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 27a1900..95bead9 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -52,6 +52,31 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) +def cic_paint_2d(mesh, positions): + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1), + scatter_dims_to_operand_dims=(0, 1)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,4]), + dnums) + return mesh + def compensate_cic(field): """ Compensate for CiC painting