From 38f6599974254ec5a1dffa3b8c311afc242a0981 Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Mon, 20 Jan 2025 22:39:40 +0100 Subject: [PATCH] painting now accepts pytrees --- jaxpm/painting.py | 6 +++--- jaxpm/painting_utils.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index eb5c402..0210166 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -48,7 +48,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): @partial(jax.jit, static_argnums=(3, 4)) def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): - positions_structure = jax.tree_structure(positions) + positions_structure = jax.tree.structure(positions) grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh)) positions = positions.reshape((*grid_mesh.shape, 3)) @@ -171,7 +171,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): indexing='ij') , axis=0), particle_mesh) particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh) - pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) + pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) return scatter(pmid.reshape([-1, 3]), displacements.reshape([-1, 3]), particle_mesh, @@ -222,7 +222,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size): jnp.arange(original_shape[2]), indexing='ij') , axis=0), grid_mesh) - pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) + pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) pmid = pmid.reshape([-1, 3]) disp = disp.reshape([-1, 3]) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index e17a1af..09e6ee5 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -110,9 +110,9 @@ def _scatter_chunk(carry, chunk): spatial_shape) # scatter ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) - mesh_structure = jax.tree_structure(mesh) + mesh_structure = jax.tree.structure(mesh) val_flat = jax.tree.leaves(val) - val_tree = jax.tree_unflatten(mesh_structure, val_flat) + val_tree = jax.tree.unflatten(mesh_structure, val_flat) mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac) carry = mesh, offset, cell_size, mesh_shape return carry, None @@ -189,9 +189,9 @@ def _gather_chunk(carry, chunk): # gather ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac) - ind_structure = jax.tree_structure(ind) - frac_structure = jax.tree_structure(frac) - mesh_structure = jax.tree_structure(mesh) + ind_structure = jax.tree.structure(ind) + frac_structure = jax.tree.structure(frac) + mesh_structure = jax.tree.structure(mesh) val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac) return carry, val