remove fourrier_space in autoshmap

This commit is contained in:
Wassim KABALAN 2024-10-21 13:57:26 -04:00
parent 01b952701e
commit ff1c5e8362

View file

@ -31,34 +31,21 @@ from jax.sharding import PartitionSpec as P
# return autoshmap(foo_impl)() # return autoshmap(foo_impl)()
def autoshmap(f: Callable, def autoshmap(
in_specs: Specs, f: Callable,
out_specs: Specs, in_specs: Specs,
check_rep: bool = True, out_specs: Specs,
auto: frozenset[AxisName] = frozenset(), check_rep: bool = True,
in_fourrier_space=False) -> Callable: auto: frozenset[AxisName] = frozenset()) -> Callable:
"""Helper function to wrap the provided function in a shard map if """Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context.""" the code is being executed in a mesh context."""
mesh = mesh_lib.thread_resources.env.physical_mesh mesh = mesh_lib.thread_resources.env.physical_mesh
if mesh.empty: if mesh.empty:
return f return f
else: else:
if in_fourrier_space and 1 in mesh.devices.shape:
in_specs, out_specs = switch_specs((in_specs, out_specs))
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
def switch_specs(specs):
if isinstance(specs, P):
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
for ax in specs)
return P(*new_axes)
elif isinstance(specs, tuple):
return tuple(switch_specs(sub_spec) for sub_spec in specs)
else:
raise TypeError("Element must be either a PartitionSpec or a tuple")
def fft3d(x): def fft3d(x):
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
return jaxdecomp.pfft3d(x.astype(jnp.complex64)) return jaxdecomp.pfft3d(x.astype(jnp.complex64))