diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 4237c23..67d54b0 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -4,7 +4,7 @@ import jax.lax as lax from jaxpm.kernels import fftk, cic_compensation -def cic_paint(mesh, positions): +def cic_paint(mesh, positions, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] @@ -18,7 +18,9 @@ def cic_paint(mesh, positions): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - + if weight is not None: + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) dnums = jax.lax.ScatterDimensionNumbers( @@ -93,4 +95,4 @@ def compensate_cic(field): delta_k = jnp.fft.rfftn(field) delta_k = cic_compensation(kvec) * delta_k - return jnp.fft.irfftn(delta_k) \ No newline at end of file + return jnp.fft.irfftn(delta_k)