Adding option to have weights in the 3d cic paint

This commit is contained in:
Francois Lanusse 2023-05-05 19:00:08 +02:00 committed by GitHub
parent d4673d2955
commit 835fa89aec

View file

@ -4,7 +4,7 @@ import jax.lax as lax
from jaxpm.kernels import fftk, cic_compensation from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions): def cic_paint(mesh, positions, weight=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
@ -18,7 +18,9 @@ def cic_paint(mesh, positions):
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers( dnums = jax.lax.ScatterDimensionNumbers(
@ -93,4 +95,4 @@ def compensate_cic(field):
delta_k = jnp.fft.rfftn(field) delta_k = jnp.fft.rfftn(field)
delta_k = cic_compensation(kvec) * delta_k delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k) return jnp.fft.irfftn(delta_k)