mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-16 16:10:54 +00:00
remove fourrier_space in autoshmap
This commit is contained in:
parent
01b952701e
commit
ff1c5e8362
1 changed files with 6 additions and 19 deletions
|
@ -31,34 +31,21 @@ from jax.sharding import PartitionSpec as P
|
||||||
# return autoshmap(foo_impl)()
|
# return autoshmap(foo_impl)()
|
||||||
|
|
||||||
|
|
||||||
def autoshmap(f: Callable,
|
def autoshmap(
|
||||||
in_specs: Specs,
|
f: Callable,
|
||||||
out_specs: Specs,
|
in_specs: Specs,
|
||||||
check_rep: bool = True,
|
out_specs: Specs,
|
||||||
auto: frozenset[AxisName] = frozenset(),
|
check_rep: bool = True,
|
||||||
in_fourrier_space=False) -> Callable:
|
auto: frozenset[AxisName] = frozenset()) -> Callable:
|
||||||
"""Helper function to wrap the provided function in a shard map if
|
"""Helper function to wrap the provided function in a shard map if
|
||||||
the code is being executed in a mesh context."""
|
the code is being executed in a mesh context."""
|
||||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||||
if mesh.empty:
|
if mesh.empty:
|
||||||
return f
|
return f
|
||||||
else:
|
else:
|
||||||
if in_fourrier_space and 1 in mesh.devices.shape:
|
|
||||||
in_specs, out_specs = switch_specs((in_specs, out_specs))
|
|
||||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||||
|
|
||||||
|
|
||||||
def switch_specs(specs):
|
|
||||||
if isinstance(specs, P):
|
|
||||||
new_axes = tuple('y' if ax == 'x' else 'x' if ax == 'y' else ax
|
|
||||||
for ax in specs)
|
|
||||||
return P(*new_axes)
|
|
||||||
elif isinstance(specs, tuple):
|
|
||||||
return tuple(switch_specs(sub_spec) for sub_spec in specs)
|
|
||||||
else:
|
|
||||||
raise TypeError("Element must be either a PartitionSpec or a tuple")
|
|
||||||
|
|
||||||
|
|
||||||
def fft3d(x):
|
def fft3d(x):
|
||||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||||
|
|
Loading…
Add table
Reference in a new issue