From 88cff997367f85a5556fb6f60d727abd62f97601 Mon Sep 17 00:00:00 2001 From: EiffL Date: Tue, 17 May 2022 23:37:55 +0200 Subject: [PATCH] adds fix to make code jittablel --- jaxpm/lensing.py | 6 ++---- jaxpm/painting.py | 6 +++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/jaxpm/lensing.py b/jaxpm/lensing.py index 5d3d1e2..d6c4a93 100644 --- a/jaxpm/lensing.py +++ b/jaxpm/lensing.py @@ -26,11 +26,9 @@ def density_plane(positions, xy = xy / nx * plane_resolution # Selecting only particles that fall inside the volume of interest - mask = (d > (center - width / 2)) & (d <= (center + width / 2)) - xy = xy[mask] - + weight = jnp.where((d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.) # Painting density plane - density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy) + density_plane = cic_paint_2d(jnp.zeros([plane_resolution, plane_resolution]), xy, weight) # Apply density normalization density_plane = density_plane / ((nx / plane_resolution) * diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 95bead9..9e323a2 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -1,3 +1,4 @@ +from tkinter import W import jax import jax.numpy as jnp import jax.lax as lax @@ -52,10 +53,11 @@ def cic_read(mesh, positions): neighboor_coords[...,1], neighboor_coords[...,3]]*kernel).sum(axis=-1) -def cic_paint_2d(mesh, positions): +def cic_paint_2d(mesh, positions, weight): """ Paints positions onto a 2d mesh mesh: [nx, ny] positions: [npart, 2] + weight: [npart] """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -64,6 +66,8 @@ def cic_paint_2d(mesh, positions): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] + if weight is not None: + kernel = kernel * weight[...,jnp.newaxis] neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,4,2]).astype('int32'), jnp.array(mesh.shape))