From d2f1eb2fa49f2f279d7f4c5e81fbdf41c8ff60c8 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Sat, 26 Oct 2024 18:52:51 +0200 Subject: [PATCH] fix painting issue with read_cic --- jaxpm/painting.py | 50 ++++++++++++++++++++++------------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 37a86a9..838fe38 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -11,17 +11,11 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def cic_paint_impl(grid_mesh, displacement, weight=None): +def cic_paint_impl(grid_mesh, positions, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] """ - part_shape = displacement.shape - positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), - jnp.arange(part_shape[1]), - jnp.arange(part_shape[2]), - indexing='ij'), - axis=-1) + displacement positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -47,7 +41,7 @@ def cic_paint_impl(grid_mesh, displacement, weight=None): return mesh -@partial(jax.jit, static_argnums=(2, 3, 4)) +#@partial(jax.jit, static_argnums=(2, 3, 4)) def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3)) @@ -66,43 +60,46 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): halo_periods=(True, True)) grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) - print(f"shape of grid_mesh: {grid_mesh.shape}") return grid_mesh -def cic_read_impl(mesh, displacement): +def cic_read_impl(grid_mesh, positions): """ Paints positions onto mesh mesh: [nx, ny, nz] - displacement: [nx,ny,nz, 3] + positions: [nx,ny,nz, 3] """ - # Compute the position of the particles on a regular grid - part_shape = displacement.shape - positions = jnp.stack(jnp.meshgrid(jnp.arange(part_shape[0]), - jnp.arange(part_shape[1]), - jnp.arange(part_shape[2]), - indexing='ij'), - axis=-1) + displacement + # Save original shape for reshaping output later + original_shape = positions.shape + # Reshape positions to a flat list of 3D coordinates positions = positions.reshape([-1, 3]) + # Expand dimensions to calculate neighbor coordinates positions = jnp.expand_dims(positions, 1) + # Floor the positions to get the base grid cell for each particle floor = jnp.floor(positions) + # Define connections to calculate all neighbor coordinates connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) - + # Calculate the 8 neighboring coordinates neighboor_coords = floor + connection + # Calculate kernel weights based on distance from each neighboring coordinate kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - + # Modulo operation to wrap around edges if necessary neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), - jnp.array(mesh.shape)) - - return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1], - neighboor_coords[..., 3]] * kernel).sum(axis=-1).reshape( - displacement.shape[:-1]) + jnp.array(grid_mesh.shape)) + # Ensure grid_mesh shape is as expected + # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel + return (grid_mesh[neighboor_coords[..., 0], + neighboor_coords[..., 1], + neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable @partial(jax.jit, static_argnums=(2, 3)) def cic_read(grid_mesh, positions, halo_size=0, sharding=None): + original_shape = positions.shape + positions = positions.reshape((*grid_mesh.shape, 3)) + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) grid_mesh = halo_exchange(grid_mesh, @@ -114,9 +111,8 @@ def cic_read(grid_mesh, positions, halo_size=0, sharding=None): gpu_mesh=gpu_mesh, in_specs=(spec, spec), out_specs=spec)(grid_mesh, positions) - print(f"shape of displacement: {displacement.shape}") - return displacement + return displacement.reshape(original_shape[:-1]) def cic_paint_2d(mesh, positions, weight):