diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 426f9f3..04263ea 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -31,34 +31,21 @@ from jax.sharding import PartitionSpec as P # return autoshmap(foo_impl)() -def autoshmap(f: Callable, - in_specs: Specs, - out_specs: Specs, - check_rep: bool = True, - auto: frozenset[AxisName] = frozenset(), - in_fourrier_space=False) -> Callable: +def autoshmap( + f: Callable, + in_specs: Specs, + out_specs: Specs, + check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()) -> Callable: """Helper function to wrap the provided function in a shard map if the code is being executed in a mesh context.""" mesh = mesh_lib.thread_resources.env.physical_mesh if mesh.empty: return f else: - if in_fourrier_space and 1 in mesh.devices.shape: - in_specs, out_specs = switch_specs((in_specs, out_specs)) return shard_map(f, mesh, in_specs, out_specs, check_rep, auto) -def switch_specs(specs): - if isinstance(specs, P): - new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax - for ax in specs) - return P(*new_axes) - elif isinstance(specs, tuple): - return tuple(switch_specs(sub_spec) for sub_spec in specs) - else: - raise TypeError("Element must be either a PartitionSpec or a tuple") - - def fft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pfft3d(x.astype(jnp.complex64))