From 91d3292923c8b7cd4cb93b678a3fb4a63fb3ec32 Mon Sep 17 00:00:00 2001
From: Wassim Kabalan <wastondev@gmail.com>
Date: Fri, 28 Feb 2025 13:46:41 +0100
Subject: [PATCH] add correct sharding annotations for weights in painting and
 a warning for cic_paint in distributed PM

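Scalar weights keep a replicated P() spec, while per-particle weight
arrays are annotated with the same spec as the positions. A minimal
single-device sketch of the two modes (no sharding; the grid size and
particle layout below are only illustrative, not part of this patch):

    import jax.numpy as jnp
    from jaxpm.painting import cic_paint

    nc = 32
    grid = jnp.zeros((nc, nc, nc))

    # cic_paint reshapes positions to (*grid.shape, 3), i.e. one particle
    # per cell, so lay the particles out on the grid sites.
    positions = jnp.stack(
        jnp.meshgrid(*[jnp.arange(nc)] * 3, indexing='ij'),
        axis=-1).reshape(-1, 3).astype(jnp.float32)

    # Scalar weight: takes the jnp.isscalar branch, replicated P() spec.
    field = cic_paint(grid, positions, 1.)

    # Per-particle weights: an array, annotated with the positions' spec.
    weights = jnp.ones((positions.shape[0],))
    field_w = cic_paint(grid, positions, weights)

cic_paint_dx accepts weight the same way; it is now passed through
autoshmap as an argument instead of being closed over in the partial.
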
---
 jaxpm/painting.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index 3083f08..f8797f2 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk
 from jaxpm.painting_utils import gather, scatter
 
 
-def _cic_paint_impl(grid_mesh, positions, weight=None):
+def _cic_paint_impl(grid_mesh, positions, weight=1.):
     """ Paints positions onto mesh
     mesh: [nx, ny, nz]
     displacement field: [nx, ny, nz, 3]
@@ -27,12 +27,11 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
     neighboor_coords = floor + connection
     kernel = 1. - jnp.abs(positions - neighboor_coords)
     kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
-    if weight is not None:
-        if jnp.isscalar(weight):
-            kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
-        else:
-            kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
-                                  kernel)
+    if jnp.isscalar(weight):
+        kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
+    else:
+        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
+                              kernel)
 
     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@@ -48,7 +47,13 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
 
 
 @partial(jax.jit, static_argnums=(3, 4))
-def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
+def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None):
+
+    if sharding is not None:
+        print("""
+            WARNING: absolute painting is not recommended in multi-device mode.
+            Please use relative painting (cic_paint_dx) instead.
+            """)
 
     positions = positions.reshape((*grid_mesh.shape, 3))
 
@@ -57,9 +62,11 @@ def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
 
     gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
     spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    weight_spec = P() if jnp.isscalar(weight) else spec
+
     grid_mesh = autoshmap(_cic_paint_impl,
                           gpu_mesh=gpu_mesh,
-                          in_specs=(spec, spec, P()),
+                          in_specs=(spec, spec, weight_spec),
                           out_specs=spec)(grid_mesh, positions, weight)
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
@@ -151,7 +158,7 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh
 
 
-def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
+def _cic_paint_dx_impl(displacements, weight=1., halo_size=0, chunk_size=2**24):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
@@ -190,13 +197,13 @@ def cic_paint_dx(displacements,
 
     gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
     spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
+    weight_spec = P() if jnp.isscalar(weight) else spec
     grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
                                   halo_size=halo_size,
-                                  weight=weight,
                                   chunk_size=chunk_size),
                           gpu_mesh=gpu_mesh,
-                          in_specs=spec,
-                          out_specs=spec)(displacements)
+                          in_specs=(spec, weight_spec),
+                          out_specs=spec)(displacements, weight)
 
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,