diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 398e0ed..b343fc8 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -13,7 +13,8 @@ except ImportError: import jax.numpy as jnp from jax._src import mesh as mesh_lib from jax.experimental.shard_map import shard_map - +from functools import partial +from jax.sharding import PartitionSpec as P def autoshmap(f: Callable, in_specs: Specs, @@ -34,13 +35,52 @@ def fft3d(x): return jaxdecomp.pfft3d(x.astype(jnp.complex64)) else: return jnp.fft.rfftn(x) - + def ifft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pifft3d(x).real else: return jnp.fft.irfftn(x) + +def halo_exchange(x): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + return jaxdecomp.halo_exchange(x) + else: + return x + +@partial(autoshmap, + in_specs=(P('x', 'y'), P()), + out_specs=P('x', 'y')) +def slice_pad_impl(x, pad_width): + return jnp.pad(x, pad_width) + +@partial(autoshmap, + in_specs=(P('x', 'y'), P()), + out_specs=P('x', 'y')) +def slice_unpad_impl(x, pad_width): + halo_x, _ = pad_width[0] + halo_y, _ = pad_width[0] + + # Apply corrections along x + x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2]) + x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:]) + # Apply corrections along y + x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) + x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) + return x + +def slice_pad(x, pad_width): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + return slice_pad_impl(x, pad_width) + else: + return x + +def slice_unpad(x, pad_width): + if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): + return slice_unpad_impl(x, pad_width) + else: + return x def get_local_shape(mesh_shape):