From b09580d59eb89ae7343d5a04604d83145d4b2dd0 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Wed, 30 Oct 2024 01:57:32 +0100 Subject: [PATCH] Allow applying weights with relative cic paint --- jaxpm/painting.py | 35 +++++++++++++++++++++++++++-------- jaxpm/painting_utils.py | 5 +---- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index c26f895..e4363ff 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -16,6 +16,7 @@ def cic_paint_impl(grid_mesh, positions, weight=None): mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] """ + positions = positions.reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) @@ -26,7 +27,11 @@ def cic_paint_impl(grid_mesh, positions, weight=None): kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] if weight is not None: - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + if jnp.isscalar(weight): + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + else: + kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), + kernel) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), @@ -144,14 +149,18 @@ def cic_paint_2d(mesh, positions, weight): return mesh -def cic_paint_dx_impl(displacements, halo_size): +def cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] original_shape = displacements.shape particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') - + if not jnp.isscalar(weight): + if weight.shape != original_shape[:-1]: + raise ValueError("Weight shape must match particle shape") + else: + weight = weight.flatten() # Padding is forced to be zero in a single gpu run a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), @@ -161,18 +170,28 @@ def cic_paint_dx_impl(displacements, halo_size): particle_mesh = jnp.pad(particle_mesh, halo_size) pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) - pmid = pmid.reshape([-1, 3]) - return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh) + return scatter(pmid.reshape([-1, 3]), + displacements.reshape([-1, 3]), + particle_mesh, + chunk_size=2**24, + val=weight) -@partial(jax.jit, static_argnums=(1, 2)) -def cic_paint_dx(displacements, halo_size=0, sharding=None): +@partial(jax.jit, static_argnums=(1, 2, 4)) +def cic_paint_dx(displacements, + halo_size=0, + sharding=None, + weight=1.0, + chunk_size=2**24): halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) gpu_mesh = sharding.mesh if sharding is not None else None spec = sharding.spec if sharding is not None else P() - grid_mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), + grid_mesh = autoshmap(partial(cic_paint_dx_impl, + halo_size=halo_size, + weight=weight, + chunk_size=chunk_size), gpu_mesh=gpu_mesh, in_specs=spec, out_specs=spec)(displacements) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index a0319a5..8742ccd 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -104,8 +104,7 @@ def _scatter_chunk(carry, chunk): spatial_shape) # scatter ind = tuple(ind[..., i] for i in range(spatial_ndim)) - mesh = mesh.at[ind].add(val * frac) - + mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac)) carry = mesh, offset, cell_size, mesh_shape return carry, None @@ -117,11 +116,9 @@ def scatter(pmid, val=1., offset=0, cell_size=1.): - ptcl_num, spatial_ndim = pmid.shape val = jnp.asarray(val) mesh = jnp.asarray(mesh) - remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) carry = mesh, offset, cell_size, mesh.shape if remainder is not None: