adds fix to make code jittablel

This commit is contained in:
EiffL 2022-05-17 23:37:55 +02:00
parent 5dc239927f
commit 5108e56ee8
2 changed files with 7 additions and 5 deletions

View file

@ -1,3 +1,4 @@
from tkinter import W
import jax
import jax.numpy as jnp
import jax.lax as lax
@ -52,10 +53,11 @@ def cic_read(mesh, positions):
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions):
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))