diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 1337650..643e85f 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -11,7 +11,7 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def cic_paint_impl(mesh, displacement, weight=None): +def cic_paint_impl(grid_mesh, displacement, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] @@ -36,30 +36,34 @@ def cic_paint_impl(mesh, displacement, weight=None): neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), - jnp.array(mesh.shape)) + jnp.array(grid_mesh.shape)) dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), inserted_window_dims=(0, 1, 2), scatter_dims_to_operand_dims=(0, 1, 2)) - mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]), - dnums) + mesh = lax.scatter_add(grid_mesh, neighboor_coords, + kernel.reshape([-1, 8]), dnums) return mesh -@partial(jax.jit, static_argnums=(2, )) -def cic_paint(mesh, positions, halo_size=0, weight=None): +@partial(jax.jit, static_argnums=(2, 3, 4)) +def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) - mesh = autoshmap(cic_paint_impl, - in_specs=(P('x', 'y'), P('x', 'y'), P()), - out_specs=P('x', 'y'))(mesh, positions, weight) - mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True)) - mesh = slice_unpad(mesh, halo_size) - return mesh + halo_size, halo_extents = get_halo_size(halo_size, sharding) + grid_mesh = slice_pad(grid_mesh, halo_size, sharding) + + gpu_mesh = sharding.mesh if sharding is not None else None + spec = sharding.spec if sharding is not None else P() + grid_mesh = autoshmap(cic_paint_impl, + gpu_mesh=gpu_mesh, + in_specs=(spec, spec, P()), + out_specs=spec)(grid_mesh, positions, weight) + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) + return grid_mesh def cic_read_impl(mesh, displacement): @@ -92,27 +96,30 @@ def cic_read_impl(mesh, displacement): displacement.shape[:-1]) -@partial(jax.jit, static_argnums=(2, )) -def cic_read(mesh, displacement, halo_size=0): +@partial(jax.jit, static_argnums=(2, 3)) +def cic_read(mesh, displacement, halo_size=0, sharding=None): - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) + mesh = slice_pad(mesh, halo_size, sharding=sharding) mesh = halo_exchange(mesh, halo_extents=halo_extents, halo_periods=(True, True)) + gpu_mesh = sharding.mesh if sharding is not None else None + spec = sharding.spec if sharding is not None else P() displacement = autoshmap(cic_read_impl, - in_specs=(P('x', 'y'), P('x', 'y')), - out_specs=P('x', 'y'))(mesh, displacement) + gpu_mesh=gpu_mesh, + in_specs=(spec, spec), + out_specs=spec)(mesh, displacement) return displacement def cic_paint_2d(mesh, positions, weight): """ Paints positions onto a 2d mesh - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - """ + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -157,29 +164,32 @@ def cic_paint_dx_impl(displacements, halo_size): return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh) -@partial(jax.jit, static_argnums=(1, )) -def cic_paint_dx(displacements, halo_size=0): +@partial(jax.jit, static_argnums=(1, 2)) +def cic_paint_dx(displacements, halo_size=0, sharding=None): - halo_size, halo_extents = get_halo_size(halo_size) + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) - mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), - in_specs=(P('x', 'y')), - out_specs=P('x', 'y'))(displacements) + gpu_mesh = sharding.mesh if sharding is not None else None + spec = sharding.spec if sharding is not None else P() + grid_mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), + gpu_mesh=gpu_mesh, + in_specs=spec, + out_specs=spec)(displacements) - mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True)) - mesh = slice_unpad(mesh, halo_size) - return mesh + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) + return grid_mesh -def cic_read_dx_impl(mesh, halo_size): +def cic_read_dx_impl(grid_mesh, halo_size): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] original_shape = [ - dim - 2 * halo[0] for dim, halo in zip(mesh.shape, halo_size) + dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size) ] a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), jnp.arange(original_shape[1]), @@ -190,32 +200,36 @@ def cic_read_dx_impl(mesh, halo_size): pmid = pmid.reshape([-1, 3]) - return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape) + return gather(pmid, jnp.zeros_like(pmid), + grid_mesh).reshape(original_shape) -@partial(jax.jit, static_argnums=(1, )) -def cic_read_dx(mesh, halo_size=0): +@partial(jax.jit, static_argnums=(1, 2)) +def cic_read_dx(grid_mesh, halo_size=0, sharding=None): # return mesh - halo_size, halo_extents = get_halo_size(halo_size) - mesh = slice_pad(mesh, halo_size) - mesh = halo_exchange(mesh, - halo_extents=halo_extents, - halo_periods=(True, True)) + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) + grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) + grid_mesh = halo_exchange(grid_mesh, + halo_extents=halo_extents, + halo_periods=(True, True)) + gpu_mesh = sharding.mesh if sharding is not None else None + spec = sharding.spec if sharding is not None else P() displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), - in_specs=(P('x', 'y')), - out_specs=P('x', 'y'))(mesh) + gpu_mesh=gpu_mesh, + in_specs=(spec), + out_specs=spec)(grid_mesh) return displacements def compensate_cic(field): """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ nc = field.shape kvec = fftk(nc)