forked from guilhem_lavaux/JaxPM
Adding option to have weights in the 3d cic paint
This commit is contained in:
parent
d4673d2955
commit
835fa89aec
1 changed files with 5 additions and 3 deletions
|
@ -4,7 +4,7 @@ import jax.lax as lax
|
|||
|
||||
from jaxpm.kernels import fftk, cic_compensation
|
||||
|
||||
def cic_paint(mesh, positions):
|
||||
def cic_paint(mesh, positions, weight=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
|
@ -18,7 +18,9 @@ def cic_paint(mesh, positions):
|
|||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
if weight is not None:
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape))
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
|
@ -93,4 +95,4 @@ def compensate_cic(field):
|
|||
|
||||
delta_k = jnp.fft.rfftn(field)
|
||||
delta_k = cic_compensation(kvec) * delta_k
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
return jnp.fft.irfftn(delta_k)
|
||||
|
|
Loading…
Add table
Reference in a new issue