forked from guilhem_lavaux/JaxPM
adds fix to make code jittablel
This commit is contained in:
parent
5dc239927f
commit
5108e56ee8
2 changed files with 7 additions and 5 deletions
|
@ -26,11 +26,9 @@ def density_plane(positions,
|
||||||
xy = xy / nx * plane_resolution
|
xy = xy / nx * plane_resolution
|
||||||
|
|
||||||
# Selecting only particles that fall inside the volume of interest
|
# Selecting only particles that fall inside the volume of interest
|
||||||
mask = (d > (center - width / 2)) & (d <= (center + width / 2))
|
weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
||||||
xy = xy[mask]
|
|
||||||
|
|
||||||
# Painting density plane
|
# Painting density plane
|
||||||
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy)
|
density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
||||||
|
|
||||||
# Apply density normalization
|
# Apply density normalization
|
||||||
density_plane = density_plane / ((nx / plane_resolution) *
|
density_plane = density_plane / ((nx / plane_resolution) *
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from tkinter import W
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
|
@ -52,10 +53,11 @@ def cic_read(mesh, positions):
|
||||||
neighboor_coords[...,1],
|
neighboor_coords[...,1],
|
||||||
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
neighboor_coords[...,3]]*kernel).sum(axis=-1)
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
""" Paints positions onto a 2d mesh
|
""" Paints positions onto a 2d mesh
|
||||||
mesh: [nx, ny]
|
mesh: [nx, ny]
|
||||||
positions: [npart, 2]
|
positions: [npart, 2]
|
||||||
|
weight: [npart]
|
||||||
"""
|
"""
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions):
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1]
|
kernel = kernel[..., 0] * kernel[..., 1]
|
||||||
|
if weight is not None:
|
||||||
|
kernel = kernel * weight[...,jnp.newaxis]
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue