painting now accepts pytrees

This commit is contained in:
Wassim Kabalan 2025-01-20 22:39:40 +01:00
parent 9e203b5680
commit 38f6599974
2 changed files with 8 additions and 8 deletions

View file

@ -48,7 +48,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions_structure = jax.tree_structure(positions)
positions_structure = jax.tree.structure(positions)
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
positions = positions.reshape((*grid_mesh.shape, 3))
@ -171,7 +171,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
indexing='ij') , axis=0), particle_mesh)
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
@ -222,7 +222,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
jnp.arange(original_shape[2]),
indexing='ij') , axis=0), grid_mesh)
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])

View file

@ -110,9 +110,9 @@ def _scatter_chunk(carry, chunk):
spatial_shape)
# scatter
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
mesh_structure = jax.tree_structure(mesh)
mesh_structure = jax.tree.structure(mesh)
val_flat = jax.tree.leaves(val)
val_tree = jax.tree_unflatten(mesh_structure, val_flat)
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
carry = mesh, offset, cell_size, mesh_shape
return carry, None
@ -189,9 +189,9 @@ def _gather_chunk(carry, chunk):
# gather
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
ind_structure = jax.tree_structure(ind)
frac_structure = jax.tree_structure(frac)
mesh_structure = jax.tree_structure(mesh)
ind_structure = jax.tree.structure(ind)
frac_structure = jax.tree.structure(frac)
mesh_structure = jax.tree.structure(mesh)
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
return carry, val