diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 3083f08..eb5c402 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -19,37 +19,37 @@ def _cic_paint_impl(grid_mesh, positions, weight=None): """ positions = positions.reshape([-1, 3]) - positions = jnp.expand_dims(positions, 1) - floor = jnp.floor(positions) + positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions) + floor = jax.tree.map(jnp.floor , positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) neighboor_coords = floor + connection - kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords)) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] if weight is not None: - if jnp.isscalar(weight): - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + if jax.tree.all(jax.tree.map(jnp.isscalar, weight)): + kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1) + , k) , kernel , weight) else: - kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), - kernel) + kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight) - neighboor_coords = jnp.mod( - neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - jnp.array(grid_mesh.shape)) + neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(grid_mesh, neighboor_coords, - kernel.reshape([-1, 8]), dnums) + mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel) + return mesh @partial(jax.jit, static_argnums=(3, 4)) def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): + positions_structure = jax.tree_structure(positions) + grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh)) positions = positions.reshape((*grid_mesh.shape, 3)) halo_size, halo_extents = get_halo_size(halo_size, sharding) @@ -79,25 +79,25 @@ def _cic_read_impl(grid_mesh, positions): # Reshape positions to a flat list of 3D coordinates positions = positions.reshape([-1, 3]) # Expand dimensions to calculate neighbor coordinates - positions = jnp.expand_dims(positions, 1) + positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions) # Floor the positions to get the base grid cell for each particle - floor = jnp.floor(positions) + floor = jax.tree.map(jnp.floor , positions) # Define connections to calculate all neighbor coordinates connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) # Calculate the 8 neighboring coordinates neighboor_coords = floor + connection # Calculate kernel weights based on distance from each neighboring coordinate - kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] # Modulo operation to wrap around edges if necessary - neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), - jnp.array(grid_mesh.shape)) + neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32') + ,jnp.array(grid_mesh.shape)) , neighboor_coords) + # Ensure grid_mesh shape is as expected # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel - return (grid_mesh[neighboor_coords[..., 0], - neighboor_coords[..., 1], - neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable + grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel) + return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable @partial(jax.jit, static_argnums=(2, 3)) @@ -157,7 +157,7 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): halo_y, _ = halo_size[1] original_shape = displacements.shape - particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32') + particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements) if not jnp.isscalar(weight): if weight.shape != original_shape[:-1]: raise ValueError("Weight shape must match particle shape") @@ -165,13 +165,13 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): weight = weight.flatten() # Padding is forced to be zero in a single gpu run - a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]), - jnp.arange(particle_mesh.shape[1]), - jnp.arange(particle_mesh.shape[2]), - indexing='ij') - - particle_mesh = jnp.pad(particle_mesh, halo_size) - pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]), + jnp.arange(x.shape[1]), + jnp.arange(x.shape[2]), + indexing='ij') , axis=0), particle_mesh) + + particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh) + pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) return scatter(pmid.reshape([-1, 3]), displacements.reshape([-1, 3]), particle_mesh, @@ -217,9 +217,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size): jnp.arange(original_shape[1]), jnp.arange(original_shape[2]), indexing='ij') + a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]), + jnp.arange(original_shape[1]), + jnp.arange(original_shape[2]), + indexing='ij') , axis=0), grid_mesh) - pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) - + pmid = jax.tree_map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) pmid = pmid.reshape([-1, 3]) disp = disp.reshape([-1, 3]) diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index cf68f9d..e17a1af 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays): def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape): """Multilinear enmeshing.""" - base_indices = jnp.asarray(base_indices) - displacements = jnp.asarray(displacements) + base_indices = jax.tree.map(jnp.asarray , base_indices) + displacements = jax.tree.map(jnp.asarray , displacements) with jax.experimental.enable_x64(): cell_size = jnp.float64( cell_size) if new_cell_size is not None else jnp.array( @@ -61,7 +61,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_displacements = particle_positions - new_indices * new_cell_size if base_shape is not None: - new_displacements -= jnp.rint( + new_displacements -= jax.tree.map(jnp.rint , new_displacements / grid_length ) * grid_length # also abs(new_displacements) < new_cell_size is expected @@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, if base_shape is not None: new_indices %= base_shape - weights = 1 - jnp.abs(new_displacements) + weights = 1 - jax.tree.map(jnp.abs , new_displacements) if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None new_indices = jnp.where(new_indices < 0, new_shape, new_indices) @@ -109,8 +109,11 @@ def _scatter_chunk(carry, chunk): ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, spatial_shape) # scatter - ind = tuple(ind[..., i] for i in range(spatial_ndim)) - mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac)) + ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) + mesh_structure = jax.tree_structure(mesh) + val_flat = jax.tree.leaves(val) + val_tree = jax.tree_unflatten(mesh_structure, val_flat) + mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac) carry = mesh, offset, cell_size, mesh_shape return carry, None @@ -122,9 +125,10 @@ def scatter(pmid, val=1., offset=0, cell_size=1.): + ptcl_num, spatial_ndim = pmid.shape - val = jnp.asarray(val) - mesh = jnp.asarray(mesh) + val = jax.tree.map(jnp.asarray , val) + mesh = jax.tree.map(jnp.asarray , mesh) remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) carry = mesh, offset, cell_size, mesh.shape if remainder is not None: @@ -147,9 +151,9 @@ def _chunk_cat(remainder_array, chunked_array): def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.): ptcl_num, spatial_ndim = pmid.shape - mesh = jnp.asarray(mesh) + mesh = jax.tree.map(jnp.asarray , mesh) - val = jnp.asarray(val) + val = jax.tree.map(jnp.asarray , val) if mesh.shape[spatial_ndim:] != val.shape[1:]: raise ValueError('channel shape mismatch: ' @@ -183,8 +187,11 @@ def _gather_chunk(carry, chunk): spatial_shape) # gather - ind = tuple(ind[..., i] for i in range(spatial_ndim)) - frac = jnp.expand_dims(frac, chan_axis) - val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1) + ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) + frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac) + ind_structure = jax.tree_structure(ind) + frac_structure = jax.tree_structure(frac) + mesh_structure = jax.tree_structure(mesh) + val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac) return carry, val