diff --git a/jaxpm/painting.py b/jaxpm/painting.py index cf23d63..7160913 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -150,9 +150,8 @@ def cic_paint_dx_impl(displacements, halo_size): jnp.arange(particle_mesh.shape[1]), jnp.arange(particle_mesh.shape[2]), indexing='ij') - + particle_mesh = jnp.pad(particle_mesh, halo_size) - pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) pmid = pmid.reshape([-1, 3]) return scatter(pmid, displacements.reshape([-1, 3]), particle_mesh) @@ -160,12 +159,13 @@ def cic_paint_dx_impl(displacements, halo_size): @partial(jax.jit, static_argnums=(1, )) def cic_paint_dx(displacements, halo_size=0): - + halo_size, halo_extents = get_halo_size(halo_size) - + mesh = autoshmap(partial(cic_paint_dx_impl, halo_size=halo_size), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(displacements) + mesh = halo_exchange(mesh, halo_extents=halo_extents, halo_periods=(True, True, True)) @@ -173,16 +173,19 @@ def cic_paint_dx(displacements, halo_size=0): return mesh -def cic_read_dx_impl(mesh): +def cic_read_dx_impl(mesh , halo_size): - original_shape = mesh.shape + halo_x, _ = halo_size[0] + halo_y, _ = halo_size[1] + original_shape = [dim - 2 * halo[0] for dim , halo in zip(mesh.shape, halo_size)] a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]), jnp.arange(original_shape[1]), jnp.arange(original_shape[2]), indexing='ij') - pmid = jnp.stack([a, b, c], axis=-1) + pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) + pmid = pmid.reshape([-1, 3]) return gather(pmid, jnp.zeros_like(pmid), mesh).reshape(original_shape) @@ -190,15 +193,16 @@ def cic_read_dx_impl(mesh): @partial(jax.jit, static_argnums=(1, )) def cic_read_dx(mesh, halo_size=0): - + # return mesh halo_size, halo_extents = get_halo_size(halo_size) mesh = slice_pad(mesh, halo_size) mesh = halo_exchange(mesh, halo_extents=halo_extents, halo_periods=(True, True, True)) - displacements = autoshmap(cic_read_dx_impl, + displacements = autoshmap(partial(cic_read_dx_impl , halo_size=halo_size), in_specs=(P('x', 'y')), out_specs=P('x', 'y'))(mesh) + return displacements