mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
format
This commit is contained in:
parent
831291c1f9
commit
ece8c93540
12 changed files with 210 additions and 170 deletions
|
@ -44,17 +44,20 @@ def autoshmap(f: Callable,
|
|||
return f
|
||||
else:
|
||||
if in_fourrier_space and 1 in mesh.devices.shape:
|
||||
in_specs , out_specs = switch_specs((in_specs , out_specs))
|
||||
in_specs, out_specs = switch_specs((in_specs, out_specs))
|
||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||
|
||||
|
||||
def switch_specs(specs):
|
||||
if isinstance(specs, P):
|
||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax for ax in specs)
|
||||
return P(*new_axes)
|
||||
elif isinstance(specs, tuple):
|
||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
||||
else:
|
||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
||||
if isinstance(specs, P):
|
||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
|
||||
for ax in specs)
|
||||
return P(*new_axes)
|
||||
elif isinstance(specs, tuple):
|
||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
||||
else:
|
||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
||||
|
||||
|
||||
def fft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
|
@ -105,14 +108,15 @@ def slice_unpad_impl(x, pad_width):
|
|||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
|
||||
|
||||
unpad_slice = [slice(None)] * 3
|
||||
if halo_x > 0:
|
||||
unpad_slice[0] = slice(halo_x , -halo_x)
|
||||
unpad_slice[0] = slice(halo_x, -halo_x)
|
||||
if halo_y > 0:
|
||||
unpad_slice[1] = slice(halo_y , -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
unpad_slice[1] = slice(halo_y, -halo_y)
|
||||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
|
||||
def slice_pad(x, pad_width):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue