From 835fa89aec01bd6f54c1c3631bd0203dad32c5da Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Fri, 5 May 2023 19:00:08 +0200 Subject: [PATCH] Adding option to have weights in the 3d cic paint --- jaxpm/painting.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 4237c23..67d54b0 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -4,7 +4,7 @@ import jax.lax as lax from jaxpm.kernels import fftk, cic_compensation -def cic_paint(mesh, positions): +def cic_paint(mesh, positions, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] @@ -18,7 +18,9 @@ def cic_paint(mesh, positions): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - + if weight is not None: + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) dnums = jax.lax.ScatterDimensionNumbers( @@ -93,4 +95,4 @@ def compensate_cic(field): delta_k = jnp.fft.rfftn(field) delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) \ No newline at end of file + return jnp.fft.irfftn(delta_k)