diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py
index 35d5e8b..bdae517 100644
--- a/jaxpm/distributed.py
+++ b/jaxpm/distributed.py
@@ -7,37 +7,26 @@ from functools import partial
 
 import jax
 import jax.numpy as jnp
+import jaxdecomp
 from jax import lax
-from jax._src import mesh as mesh_lib
 from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P
 
-# NOTE
-# This should not be used as a decorator
-# Must be used inside a function only
-# Example
-# BAD
-# @autoshmap
-# def foo():
-#    pass
-# GOOD
-# def foo():
-#    return autoshmap(foo_impl)()
-
 
 def autoshmap(
         f: Callable,
+        gpu_mesh: Mesh | None,
         in_specs: Specs,
         out_specs: Specs,
         check_rep: bool = True,
         auto: frozenset[AxisName] = frozenset()) -> Callable:
     """Helper function to wrap the provided function in a shard map if
     the code is being executed in a mesh context."""
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if mesh.empty:
+    if gpu_mesh is None or gpu_mesh.empty:
         return f
     else:
-        return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
+        return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
 
 
 def fft3d(x):
@@ -48,14 +37,14 @@ def ifft3d(x):
     return jaxdecomp.pifft3d(x).real
 
 
-def get_halo_size(halo_size):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if mesh.empty:
+def get_halo_size(halo_size, sharding):
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if gpu_mesh is None or gpu_mesh.empty:
         zero_ext = (0, 0, 0)
         zero_tuple = (0, 0)
         return (zero_tuple, zero_tuple, zero_tuple), zero_ext
     else:
-        pdims = mesh.devices.shape
+        pdims = gpu_mesh.devices.shape
 
         halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
         halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
@@ -91,44 +80,52 @@ def slice_unpad_impl(x, pad_width):
     return x[tuple(unpad_slice)]
 
 
-def slice_pad(x, pad_width):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
-                                             or pad_width[1][0] > 0):
-        return autoshmap((partial(jnp.pad, pad_width=pad_width)),
-                         in_specs=(P('x', 'y')),
-                         out_specs=P('x', 'y'))(x)
+def slice_pad(x, pad_width, sharding):
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if not gpu_mesh is None and not (gpu_mesh.empty) and (
+            pad_width[0][0] > 0 or pad_width[1][0] > 0):
+        assert sharding is not None
+        spec = sharding.spec
+        return shard_map((partial(jnp.pad, pad_width=pad_width)),
+                         mesh=gpu_mesh,
+                         in_specs=spec,
+                         out_specs=spec)(x)
     else:
         return x
 
 
-def slice_unpad(x, pad_width):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
-                                             or pad_width[1][0] > 0):
-        return autoshmap(partial(slice_unpad_impl, pad_width=pad_width),
-                         in_specs=(P('x', 'y')),
-                         out_specs=P('x', 'y'))(x)
+def slice_unpad(x, pad_width, sharding):
+    mesh = sharding.mesh if sharding is not None else None
+    if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
+                                                  or pad_width[1][0] > 0):
+        assert sharding is not None
+        spec = sharding.spec
+        return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
+                         mesh=mesh,
+                         in_specs=spec,
+                         out_specs=spec)(x)
     else:
         return x
 
 
-def get_local_shape(mesh_shape):
+def get_local_shape(mesh_shape, sharding):
     """ Helper function to get the local size of a mesh given the global size.
""" - if mesh_lib.thread_resources.env.physical_mesh.empty: + gpu_mesh = sharding.mesh if sharding is not None else None + if gpu_mesh is None or gpu_mesh.empty: return mesh_shape else: - pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape + pdims = gpu_mesh.devices.shape return [ mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2] ] -def normal_field(mesh_shape, seed): +def normal_field(mesh_shape, seed, sharding): """Generate a Gaussian random field with the given power spectrum.""" - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - local_mesh_shape = get_local_shape(mesh_shape) + gpu_mesh = sharding.mesh if sharding is not None else None + if not gpu_mesh is None and not (gpu_mesh.empty): + local_mesh_shape = get_local_shape(mesh_shape, sharding) size = jax.device_count() # rank = jax.process_index() @@ -136,16 +133,36 @@ def normal_field(mesh_shape, seed): # to make the code work both in multi host and single controller we can do this trick keys = jax.random.split(seed, size) + spec = sharding.spec + if len(spec) == 1: + x_axis, = spec + y_axis = None + single_axis = True + elif len(spec) == 2: + x_axis, y_axis = spec + if y_axis == None: + single_axis = True + elif x_axis == None: + x_axis = y_axis + single_axis = True + else: + single_axis = False + else: + raise ValueError("Only 1 or 2 axis sharding is supported") + def normal(keys, shape, dtype): - x_index = lax.axis_index('x') - y_index = lax.axis_index('y') - x_size = lax.psum(1, axis_name='x') - idx = x_index + y_index * x_size + idx = lax.axis_index(x_axis) + if single_axis: + y_index = lax.axis_index(y_axis) + x_size = lax.psum(1, axis_name=x_axis) + idx += y_index * x_size + return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype) - return autoshmap( + return shard_map( partial(normal, shape=local_mesh_shape, dtype='float32'), + mesh=gpu_mesh, in_specs=P(None), - out_specs=P('x', 'y'))(keys) # yapf: disable + out_specs=spec)(keys) # yapf: disable else: return jax.random.normal(shape=mesh_shape, key=seed)