diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py
index 721a971..a094a6b 100644
--- a/jaxpm/distributed.py
+++ b/jaxpm/distributed.py
@@ -79,6 +79,21 @@ def slice_unpad_impl(x, pad_width):
     return x[tuple(unpad_slice)]
 
 
+def slice_pad_impl(x, pad_width):
+
+    x = jnp.pad(x, pad_width=pad_width)
+
+    halo_x, _ = pad_width[0]
+    halo_y, _ = pad_width[1]
+
+    # Apply corrections along x
+    x = x.at[halo_x // 2:halo_x].add(x[halo_x:halo_x + halo_x // 2])
+    x = x.at[-halo_x:-halo_x // 2].add(x[-(halo_x + halo_x // 2):-halo_x])
+    # Apply corrections along y
+    x = x.at[:, halo_y // 2:halo_y].add(x[:, halo_y:halo_y + halo_y // 2])
+    x = x.at[:, -halo_y:-halo_y // 2].add(x[:, -(halo_y + halo_y // 2):-halo_y])
+
+    return x
 
 def slice_pad(x, pad_width, sharding):
     gpu_mesh = sharding.mesh if sharding is not None else None
@@ -86,7 +101,7 @@ def slice_pad(x, pad_width, sharding):
             pad_width[0][0] > 0 or pad_width[1][0] > 0):
         assert sharding is not None
         spec = sharding.spec
-        return shard_map((partial(jnp.pad, pad_width=pad_width)),
+        return shard_map((partial(slice_pad_impl, pad_width=pad_width)),
                          mesh=gpu_mesh,
                          in_specs=spec,
                          out_specs=spec)(x)
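
To see the correction in isolation, here is a minimal standalone sketch of the new `slice_pad_impl` run on a single 2-D shard. The function body is copied from the patch above; the toy driver and the 2-cell halo size are illustrative assumptions, and in JaxPM the function actually executes per device shard under `shard_map` rather than on a full array.

```python
# Standalone sketch of the patched slice_pad_impl. The body mirrors the
# diff above; the demo at the bottom is a hypothetical single-shard run.
import jax.numpy as jnp


def slice_pad_impl(x, pad_width):
    # Zero-pad the local shard with halo regions on each side.
    x = jnp.pad(x, pad_width=pad_width)

    halo_x, _ = pad_width[0]
    halo_y, _ = pad_width[1]

    # Each correction adds the first halo//2 cells of real edge data
    # (source indices [halo, halo + halo // 2)) into the inner half of
    # the freshly created zero halo (destination [halo // 2, halo)),
    # and symmetrically at the far edge.
    # Apply corrections along x
    x = x.at[halo_x // 2:halo_x].add(x[halo_x:halo_x + halo_x // 2])
    x = x.at[-halo_x:-halo_x // 2].add(x[-(halo_x + halo_x // 2):-halo_x])
    # Apply corrections along y
    x = x.at[:, halo_y // 2:halo_y].add(x[:, halo_y:halo_y + halo_y // 2])
    x = x.at[:, -halo_y:-halo_y // 2].add(x[:, -(halo_y + halo_y // 2):-halo_y])

    return x


# Toy example: a 4x4 shard of ones with a 2-cell halo in both dimensions.
shard = jnp.ones((4, 4))
padded = slice_pad_impl(shard, pad_width=((2, 2), (2, 2)))
print(padded.shape)  # (8, 8): the 4x4 shard plus a 2-cell halo all around
print(padded)        # inner halo halves now carry copies of the edge data
```

The practical difference from the replaced `partial(jnp.pad, pad_width=pad_width)` is that the halo no longer stays all-zero: edge data is duplicated into the inner half of each halo, so overlapping regions carry the contributions a subsequent halo exchange can accumulate (that reading of the intent is an inference, not stated in the patch itself).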