Shared operation in fourrier space now take inverted sharding axis for

slabs
This commit is contained in:
Wassim KABALAN 2024-08-02 21:21:00 +02:00
parent 8c5bd76c33
commit 75604d2396
2 changed files with 23 additions and 8 deletions

View file

@ -35,15 +35,26 @@ def autoshmap(f: Callable,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = True,
auto: frozenset[AxisName] = frozenset()):
auto: frozenset[AxisName] = frozenset(),
in_fourrier_space=False) -> Callable:
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty:
return f
else:
if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
@ -87,17 +98,21 @@ def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
def slice_unpad_impl(x, pad_width):
halo_x, _ = pad_width[0]
halo_y, _ = pad_width[0]
halo_y, _ = pad_width[1]
# Apply corrections along x
x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
return x[halo_x:-halo_x, halo_y:-halo_y, :]
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh

View file

@ -46,7 +46,7 @@ def fftk(shape, dtype=np.float32):
@partial(autoshmap,
in_specs=(P('x'), P('y'), P(None)),
out_specs=(P('x'), P(None, 'y'), P(None)))
out_specs=(P('x'), P(None, 'y'), P(None)),in_fourrier_space=True)
def get_kvec(ky, kz, kx):
return (ky.reshape([-1, 1, 1]),
kz.reshape([1, -1, 1]),
@ -73,7 +73,7 @@ def interpolate_power_spectrum(input, k, pk):
pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape(-1), k, pk
).reshape(x.shape)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'))(input)
return autoshmap(pk_fn, in_specs=P('x', 'y'), out_specs=P('x', 'y'),in_fourrier_space=True)(input)
def gradient_kernel(kvec, direction, order=1):