This commit is contained in:
Wassim KABALAN 2024-08-03 00:23:40 +02:00
parent 831291c1f9
commit ece8c93540
12 changed files with 210 additions and 170 deletions

View file

@ -44,17 +44,20 @@ def autoshmap(f: Callable,
return f
else:
if in_fourrier_space and 1 in mesh.devices.shape:
in_specs , out_specs = switch_specs((in_specs , out_specs))
in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width):
# Apply corrections along y
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
unpad_slice = [slice(None)] * 3
if halo_x > 0:
unpad_slice[0] = slice(halo_x , -halo_x)
unpad_slice[0] = slice(halo_x, -halo_x)
if halo_y > 0:
unpad_slice[1] = slice(halo_y , -halo_y)
return x[tuple(unpad_slice)]
unpad_slice[1] = slice(halo_y, -halo_y)
return x[tuple(unpad_slice)]
def slice_pad(x, pad_width):
mesh = mesh_lib.thread_resources.env.physical_mesh