From 75604d2396cab2fffbc3115d78b8056147398ef1 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 2 Aug 2024 21:21:00 +0200 Subject: [PATCH] Shared operation in fourrier space now take inverted sharding axis for slabs --- jaxpm/distributed.py | 27 +++++++++++++++++++++------ jaxpm/kernels.py | 4 ++-- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index c42979e..9fb0e15 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -35,15 +35,26 @@ def autoshmap(f: Callable, in_specs: Specs, out_specs: Specs, check_rep: bool = True, - auto: frozenset[AxisName] = frozenset()): + auto: frozenset[AxisName] = frozenset(), + in_fourrier_space=False) -> Callable: """Helper function to wrap the provided function in a shard map if the code is being executed in a mesh context.""" mesh = mesh_lib.thread_resources.env.physical_mesh if mesh.empty: return f else: + if in_fourrier_space and 1 in mesh.devices.shape: + in_specs , out_specs = switch_specs((in_specs , out_specs)) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) +def switch_specs(specs): + if isinstance(specs, P): + new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs) + return P(*new_axes) + elif isinstance(specs, tuple): + return tuple(switch_specs(sub_spec) for sub_spec in specs) + else: + raise TypeError("Element must be either a PartitionSpec or a tuple") def fft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): @@ -87,17 +98,21 @@ def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): def slice_unpad_impl(x, pad_width): halo_x, _ = pad_width[0] - halo_y, _ = pad_width[0] - + halo_y, _ = pad_width[1] # Apply corrections along x x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2]) x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:]) # Apply corrections along y x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) - - return x[halo_x:-halo_x, halo_y:-halo_y, :] - + + unpad_slice = [slice(None)] * 3 + if halo_x > 0: + unpad_slice[0] = slice(halo_x , -halo_x) + if halo_y > 0: + unpad_slice[1] = slice(halo_y , -halo_y) + + return x[tuple(unpad_slice)] def slice_pad(x, pad_width): mesh = mesh_lib.thread_resources.env.physical_mesh diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 0025fa8..bfb7e7e 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -46,7 +46,7 @@ def fftk(shape, dtype=np.float32): @partial(autoshmap, in_specs=(P('x'), P('y'), P(None)), - out_specs=(P('x'), P(None, 'y'), P(None))) + out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True) def get_kvec(ky, kz, kx): return (ky.reshape([-1, 1, 1]), kz.reshape([1, -1, 1]), @@ -73,7 +73,7 @@ def interpolate_power_spectrum(input, k, pk): pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk ).reshape(x.shape) - return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input) + return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input) def gradient_kernel(kvec, direction, order=1):