From 9b8e42aa00031cf622478bd77f5d11917b1a19f6 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Wed, 5 Feb 2025 14:26:28 +0100 Subject: [PATCH] updates for FullField --- jaxpm/painting.py | 35 ++++++++++++++++++++++++----------- jaxpm/utils.py | 2 +- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index b6b22a4..ef00fdb 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -140,26 +140,39 @@ def cic_paint_2d(mesh, positions, weight): positions: [npart, 2] weight: [npart] """ - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) - connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) + positions = positions.reshape([-1, 2]) + positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions) + floor = jax.tree.map(jnp.floor, positions) + connection = jnp.array([[[0, 0], [1., 0], [0., 1], [1., 1]]]) neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] - if weight is not None: - kernel = kernel * weight[..., jnp.newaxis] - neighboor_coords = jnp.mod( - neighboor_coords.reshape([-1, 4, 2]).astype('int32'), - jnp.array(mesh.shape)) + if weight is not None: + if jax.tree.all(jax.tree.map(jnp.isscalar, weight)): + kernel = jax.tree.map( + lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k), + kernel, weight) + else: + kernel = jax.tree.map( + lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k), + kernel, weight) + + neighboor_coords = jax.tree.map( + lambda nc: jnp.mod( + nc.reshape([-1, 4, 2]).astype('int32'), jnp.array(mesh.shape) + ), neighboor_coords) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1), scatter_dims_to_operand_dims=(0, 1)) - mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 4]), - dnums) + mesh = jax.tree.map( + lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 4]), dnums), + mesh, neighboor_coords, kernel) + + return mesh diff --git a/jaxpm/utils.py b/jaxpm/utils.py index db33bb2..3c2cce1 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -221,4 +221,4 @@ def gaussian_smoothing(im, sigma): filter = norm.pdf(k, 0, 1. / (2. * np.pi * sigma)) filter /= filter[0, 0] - return jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real + return jax.tree.map(lambda im : jnp.fft.ifft2(jnp.fft.fft2(im) * filter).real , im)