mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
global mesh no longer needed
This commit is contained in:
parent
38714cf65d
commit
591ee32c55
1 changed file with 63 additions and 46 deletions
@@ -7,37 +7,26 @@ from functools import partial
 import jax
 import jax.numpy as jnp
 import jaxdecomp
 from jax import lax
-from jax._src import mesh as mesh_lib
 from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P

-# NOTE
-# This should not be used as a decorator
-# Must be used inside a function only
-# Example
-# BAD
-# @autoshmap
-# def foo():
-#     pass
-# GOOD
-# def foo():
-#     return autoshmap(foo_impl)()


 def autoshmap(
     f: Callable,
+    gpu_mesh: Mesh | None,
     in_specs: Specs,
     out_specs: Specs,
     check_rep: bool = True,
     auto: frozenset[AxisName] = frozenset()) -> Callable:
     """Helper function to wrap the provided function in a shard map if
     the code is being executed in a mesh context."""
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if mesh.empty:
+    if gpu_mesh is None or gpu_mesh.empty:
         return f
     else:
-        return shard_map(f, mesh, in_specs, out_specs, check_rep, auto)
+        return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)


 def fft3d(x):
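A minimal caller-side sketch of the reworked autoshmap, not part of the diff: the mesh is now passed in explicitly instead of being read from mesh_lib.thread_resources. The jaxpm.distributed import path, the 2 x 2 device grid, and the doubling kernel below are illustrative assumptions.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

from jaxpm.distributed import autoshmap  # assumed import path

# Illustrative 2 x 2 device grid; assumes at least four devices are visible.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
gpu_mesh = Mesh(devices, axis_names=('x', 'y'))


def double(x):
    # Follows the "GOOD" pattern from the removed NOTE: build the shard map
    # inside a function, passing the mesh explicitly.
    return autoshmap(lambda a: a * 2.0,
                     gpu_mesh,
                     in_specs=P('x', 'y'),
                     out_specs=P('x', 'y'))(x)


x = jnp.ones((8, 8, 8))
print(double(x).shape)  # (8, 8, 8); with gpu_mesh=None the kernel runs unwrapped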
@@ -48,14 +37,14 @@ def ifft3d(x):
     return jaxdecomp.pifft3d(x).real


-def get_halo_size(halo_size):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if mesh.empty:
+def get_halo_size(halo_size, sharding):
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if gpu_mesh is None or gpu_mesh.empty:
         zero_ext = (0, 0, 0)
         zero_tuple = (0, 0)
         return (zero_tuple, zero_tuple, zero_tuple), zero_ext
     else:
-        pdims = mesh.devices.shape
+        pdims = gpu_mesh.devices.shape
         halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
         halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
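For context on the hunk above, the halo is only applied along process-grid axes that are actually split across more than one device. A standalone illustration of that rule, with made-up pdims and halo_size values:

# Illustrates the halo-width rule from get_halo_size; values are made up.
pdims = (1, 8)     # gpu_mesh.devices.shape for a 1 x 8 process grid
halo_size = 32
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
print(halo_x, halo_y)  # (0, 0) (32, 32): no halo along the single-device axis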
@@ -91,44 +80,52 @@ def slice_unpad_impl(x, pad_width):
     return x[tuple(unpad_slice)]


-def slice_pad(x, pad_width):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
-                                             or pad_width[1][0] > 0):
-        return autoshmap((partial(jnp.pad, pad_width=pad_width)),
-                         in_specs=(P('x', 'y')),
-                         out_specs=P('x', 'y'))(x)
+def slice_pad(x, pad_width, sharding):
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if not gpu_mesh is None and not (gpu_mesh.empty) and (
+            pad_width[0][0] > 0 or pad_width[1][0] > 0):
+        assert sharding is not None
+        spec = sharding.spec
+        return shard_map((partial(jnp.pad, pad_width=pad_width)),
+                         mesh=gpu_mesh,
+                         in_specs=spec,
+                         out_specs=spec)(x)
     else:
         return x


-def slice_unpad(x, pad_width):
-    mesh = mesh_lib.thread_resources.env.physical_mesh
-    if distributed and not (mesh.empty) and (pad_width[0][0] > 0
-                                             or pad_width[1][0] > 0):
-        return autoshmap(partial(slice_unpad_impl, pad_width=pad_width),
-                         in_specs=(P('x', 'y')),
-                         out_specs=P('x', 'y'))(x)
+def slice_unpad(x, pad_width, sharding):
+    mesh = sharding.mesh if sharding is not None else None
+    if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
+                                                  or pad_width[1][0] > 0):
+        assert sharding is not None
+        spec = sharding.spec
+        return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
+                         mesh=mesh,
+                         in_specs=spec,
+                         out_specs=spec)(x)
     else:
         return x


-def get_local_shape(mesh_shape):
+def get_local_shape(mesh_shape, sharding):
     """ Helper function to get the local size of a mesh given the global size.
     """
-    if mesh_lib.thread_resources.env.physical_mesh.empty:
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if gpu_mesh is None or gpu_mesh.empty:
         return mesh_shape
     else:
-        pdims = mesh_lib.thread_resources.env.physical_mesh.devices.shape
+        pdims = gpu_mesh.devices.shape
         return [
             mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1], mesh_shape[2]
         ]


-def normal_field(mesh_shape, seed):
+def normal_field(mesh_shape, seed, sharding):
     """Generate a Gaussian random field with the given power spectrum."""
-    if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
-        local_mesh_shape = get_local_shape(mesh_shape)
+    gpu_mesh = sharding.mesh if sharding is not None else None
+    if not gpu_mesh is None and not (gpu_mesh.empty):
+        local_mesh_shape = get_local_shape(mesh_shape, sharding)

         size = jax.device_count()
         # rank = jax.process_index()
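The new get_local_shape above splits the first two grid axes across the process grid and leaves the third global. A worked example with illustrative sizes:

# Worked example of get_local_shape's arithmetic; sizes are illustrative.
mesh_shape = (256, 256, 256)   # global grid
pdims = (2, 4)                 # gpu_mesh.devices.shape
local_shape = [
    mesh_shape[0] // pdims[0],  # 128 along x
    mesh_shape[1] // pdims[1],  # 64 along y
    mesh_shape[2],              # 256, the z axis stays global
]
print(local_shape)  # [128, 64, 256]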
@@ -136,16 +133,36 @@ def normal_field(mesh_shape, seed):
         # to make the code work both in multi host and single controller we can do this trick
         keys = jax.random.split(seed, size)

+        spec = sharding.spec
+        if len(spec) == 1:
+            x_axis, = spec
+            y_axis = None
+            single_axis = True
+        elif len(spec) == 2:
+            x_axis, y_axis = spec
+            if y_axis == None:
+                single_axis = True
+            elif x_axis == None:
+                x_axis = y_axis
+                single_axis = True
+            else:
+                single_axis = False
+        else:
+            raise ValueError("Only 1 or 2 axis sharding is supported")

         def normal(keys, shape, dtype):
-            x_index = lax.axis_index('x')
-            y_index = lax.axis_index('y')
-            x_size = lax.psum(1, axis_name='x')
-            idx = x_index + y_index * x_size
+            idx = lax.axis_index(x_axis)
+            if not single_axis:
+                y_index = lax.axis_index(y_axis)
+                x_size = lax.psum(1, axis_name=x_axis)
+                idx += y_index * x_size

             return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)

-        return autoshmap(
+        return shard_map(
             partial(normal, shape=local_mesh_shape, dtype='float32'),
+            mesh=gpu_mesh,
             in_specs=P(None),
-            out_specs=P('x', 'y'))(keys)  # yapf: disable
+            out_specs=spec)(keys)  # yapf: disable
     else:
         return jax.random.normal(shape=mesh_shape, key=seed)
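Taken together, a hypothetical caller now builds a sharding once and threads it through the API instead of activating a global mesh context. The import path, the 2 x 2 mesh, and the grid sizes below are assumptions for illustration only.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from jaxpm.distributed import normal_field  # assumed import path

# Illustrative 2 x 2 device grid; assumes at least four devices are visible.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Distributed path: each device draws its local slab of the global field.
field = normal_field((128, 128, 128), jax.random.PRNGKey(0), sharding)

# Single-process fallback: passing sharding=None skips shard_map entirely.
small = normal_field((64, 64, 64), jax.random.PRNGKey(1), None)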