mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
painting now accepts pytrees
This commit is contained in:
parent
9e203b5680
commit
38f6599974
2 changed files with 8 additions and 8 deletions
|
@ -48,7 +48,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
@partial(jax.jit, static_argnums=(3, 4))
|
||||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
|
||||
positions_structure = jax.tree_structure(positions)
|
||||
positions_structure = jax.tree.structure(positions)
|
||||
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
|
@ -171,7 +171,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
indexing='ij') , axis=0), particle_mesh)
|
||||
|
||||
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
|
||||
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
|
@ -222,7 +222,7 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
jnp.arange(original_shape[2]),
|
||||
indexing='ij') , axis=0), grid_mesh)
|
||||
|
||||
pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
disp = disp.reshape([-1, 3])
|
||||
|
||||
|
|
|
@ -110,9 +110,9 @@ def _scatter_chunk(carry, chunk):
|
|||
spatial_shape)
|
||||
# scatter
|
||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||
mesh_structure = jax.tree_structure(mesh)
|
||||
mesh_structure = jax.tree.structure(mesh)
|
||||
val_flat = jax.tree.leaves(val)
|
||||
val_tree = jax.tree_unflatten(mesh_structure, val_flat)
|
||||
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
|
||||
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
@ -189,9 +189,9 @@ def _gather_chunk(carry, chunk):
|
|||
# gather
|
||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
|
||||
ind_structure = jax.tree_structure(ind)
|
||||
frac_structure = jax.tree_structure(frac)
|
||||
mesh_structure = jax.tree_structure(mesh)
|
||||
ind_structure = jax.tree.structure(ind)
|
||||
frac_structure = jax.tree.structure(frac)
|
||||
mesh_structure = jax.tree.structure(mesh)
|
||||
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
|
||||
|
||||
return carry, val
|
||||
|
|
Loading…
Add table
Reference in a new issue