adds fix to make code jittablel

This commit is contained in:
EiffL 2022-05-17 23:37:55 +02:00
parent 5dc239927f
commit 5108e56ee8
2 changed files with 7 additions and 5 deletions

View file

@ -26,11 +26,9 @@ def density_plane(positions,
xy = xy / nx * plane_resolution
# Selecting only particles that fall inside the volume of interest
mask = (d > (center - width / 2)) & (d <= (center + width / 2))
xy = xy[mask]
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
# Painting density plane
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy)
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *

View file

@ -1,3 +1,4 @@
from tkinter import W
import jax
import jax.numpy as jnp
import jax.lax as lax
@ -52,10 +53,11 @@ def cic_read(mesh, positions):
neighboor_coords[...,1],
neighboor_coords[...,3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions):
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions):
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1]
if weight is not None:
kernel = kernel * weight[...,jnp.newaxis]
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))