forked from guilhem_lavaux/JaxPM
adds fix to make code jittablel
This commit is contained in:
parent
5dc239927f
commit
5108e56ee8
2 changed files with 7 additions and 5 deletions
|
@ -26,11 +26,9 @@ def density_plane(positions,
|
|||
xy = xy / nx * plane_resolution
|
||||
|
||||
# Selecting only particles that fall inside the volume of interest
|
||||
mask = (d > (center - width / 2)) & (d <= (center + width / 2))
|
||||
xy = xy[mask]
|
||||
|
||||
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||
# Painting density plane
|
||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy)
|
||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||
|
||||
# Apply density normalization
|
||||
density_plane = density_plane / ((nx / plane_resolution) *
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from tkinter import W
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
|
@ -52,10 +53,11 @@ def cic_read(mesh, positions):
|
|||
neighboor_coords[...,1],
|
||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
||||
|
||||
def cic_paint_2d(mesh, positions):
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
""" Paints positions onto a 2d mesh
|
||||
mesh: [nx, ny]
|
||||
positions: [npart, 2]
|
||||
weight: [npart]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
|
@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions):
|
|||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1]
|
||||
if weight is not None:
|
||||
kernel = kernel * weight[...,jnp.newaxis]
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue