mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
update formatting
This commit is contained in:
parent
6408aff1de
commit
319942a6bc
5 changed files with 113 additions and 96 deletions
|
@ -16,11 +16,11 @@ from jax.experimental.shard_map import shard_map
|
|||
|
||||
|
||||
def autoshmap(f: Callable,
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = True,
|
||||
auto: frozenset[AxisName] = frozenset()):
|
||||
"""Helper function to wrap the provided function in a shard map if
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
check_rep: bool = True,
|
||||
auto: frozenset[AxisName] = frozenset()):
|
||||
"""Helper function to wrap the provided function in a shard map if
|
||||
the code is being executed in a mesh context."""
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
|
@ -28,23 +28,28 @@ def autoshmap(f: Callable,
|
|||
else:
|
||||
return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
|
||||
|
||||
|
||||
def fft3d(x):
|
||||
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||
else:
|
||||
return jnp.fft.rfftn(x)
|
||||
|
||||
|
||||
def ifft3d(x):
|
||||
if distributed and not(mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
else:
|
||||
return jnp.fft.irfftn(x)
|
||||
|
||||
|
||||
def get_local_shape(mesh_shape):
|
||||
""" Helper function to get the local size of a mesh given the global size.
|
||||
""" Helper function to get the local size of a mesh given the global size.
|
||||
"""
|
||||
if mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||
return mesh_shape
|
||||
else:
|
||||
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
|
||||
return [mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]]
|
||||
if mesh_lib.thread_resources.env.physical_mesh.empty:
|
||||
return mesh_shape
|
||||
else:
|
||||
pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
|
||||
return [
|
||||
mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue