mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
144 lines
4.3 KiB
Python
144 lines
4.3 KiB
Python
import os
|
|
from math import prod
|
|
|
|
# Module-level state maintained by initialize_distributed():
#   setup_done -- True once initialize_distributed() has completed its setup.
#   on_cluster -- True when that setup found SLURM_JOB_ID in the environment.
setup_done = on_cluster = False
|
|
|
|
|
|
def is_on_cluster():
    """Return the module-level ``on_cluster`` flag.

    The flag is False at import time and is set by initialize_distributed(),
    which makes it True when ``SLURM_JOB_ID`` is present in the environment.
    Only reading a global is needed here, so no ``global`` statement is used.
    """
    return on_cluster
|
|
|
|
|
|
def initialize_distributed(num_local_devices=4):
    """Configure JAX once, either for a SLURM cluster or for a local CPU run.

    On the first call:
      * if ``SLURM_JOB_ID`` is set, ``jax.distributed.initialize()`` is called
        and the module flag ``on_cluster`` becomes True;
      * otherwise JAX is pointed at the CPU platform (``JAX_PLATFORM_NAME``)
        and asked, via ``XLA_FLAGS``, to expose ``num_local_devices`` virtual
        host devices. Flags already present in ``XLA_FLAGS`` are kept, and a
        device count the caller already set there is left untouched.
    Subsequent calls return immediately because ``setup_done`` is True.

    Args:
        num_local_devices: number of virtual CPU devices requested in the
            local (non-SLURM) case. Defaults to 4.
    """
    global setup_done
    global on_cluster

    if setup_done:
        return

    if "SLURM_JOB_ID" in os.environ:
        print("Running on cluster")
        import jax
        jax.distributed.initialize()
        on_cluster = True
    else:
        print("Running locally")
        on_cluster = False
        os.environ["JAX_PLATFORM_NAME"] = "cpu"
        # Append to, rather than overwrite, any XLA flags the user provided.
        device_flag = "--xla_force_host_platform_device_count"
        existing_flags = os.environ.get("XLA_FLAGS", "")
        if device_flag not in existing_flags:
            os.environ["XLA_FLAGS"] = (
                f"{existing_flags} {device_flag}={num_local_devices}".strip())
        # Presumably imported here so JAX is first loaded only after the
        # environment variables above are in place.
        import jax  # noqa: F401

    setup_done = True
|
|
|
|
|
|
def compare_sharding(sharding1, sharding2):
    """Return True when two shardings have the same process-grid dimensions.

    The pdims of each sharding are obtained from jaxdecomp's
    ``get_pdims_from_sharding`` and right-padded with 1s to length 3 before
    comparison, so e.g. ``(2, 2)`` and ``(2, 2, 1)`` compare equal.
    """
    from jaxdecomp._src.spmd_ops import get_pdims_from_sharding

    def _padded_pdims(sharding):
        # Right-pad with trailing 1s up to three entries.
        dims = get_pdims_from_sharding(sharding)
        return dims + (1, ) * (3 - len(dims))

    left = _padded_pdims(sharding1)
    right = _padded_pdims(sharding2)
    return left == right
|
|
|
|
|
|
def replace_none_or_zero(value):
    """Return ``value``, substituting 0 when it is ``None``.

    Despite the name, only ``None`` is replaced; a 0 is returned unchanged
    (so both inputs yield 0). ``device_arange`` relies on this to turn an
    open slice start (``None``) into the offset 0, so the mapping must not
    be changed to 1.
    """
    return 0 if value is None else value
|
|
|
|
|
|
def process_slices(slices_tuple):
    """Derive a deterministic integer seed from a tuple of index slices.

    Intended for the ``data_callback`` of ``jax.make_array_from_callback``,
    which receives one ``slice`` per array axis describing a shard.

    Each slice contributes its start (``None`` -> 0) and stop (``None`` -> -1)
    to a positional fold modulo the prime ``2**31 - 1``. Position matters, so
    shards whose bounds are permutations of each other (e.g. ``(0:4, 4:8)``
    versus ``(4:8, 0:4)``) get different seeds. A fully open axis
    (``slice(None, None)``) cannot collapse the result to a constant, which
    a product of bounds would do.

    Args:
        slices_tuple: iterable of ``slice`` objects with integer or ``None``
            start/stop.

    Returns:
        int in ``[0, 2**31 - 1)``.
    """
    modulus = 2**31 - 1  # Mersenne prime; keeps the seed within 31 bits
    multiplier = 1_000_003
    seed = 17
    for s in slices_tuple:
        start = 0 if s.start is None else int(s.start)
        stop = -1 if s.stop is None else int(s.stop)
        for bound in (start, stop):
            seed = (seed * multiplier + bound) % modulus
    return int(seed)
|
|
|
|
|
|
def device_arange(pdims):
    """Build a sharded 2-D array in which every shard holds one device label.

    The device mesh from ``mesh_utils.create_device_mesh(pdims)`` is
    transposed, so mesh axis ``'z'`` has ``pdims[1]`` entries and axis
    ``'y'`` has ``pdims[0]`` entries. The global array has shape
    ``(pdims[1], pdims[0])`` and is sharded one element per device. The element at
    ``[z, y]`` is ``z + y * pdims[1]``, the row-major position of that
    device in the original ``pdims``-shaped device grid. Labels are
    therefore ``0 .. prod(pdims) - 1`` with no duplicates.

    Args:
        pdims: 2-element device-grid shape passed to ``create_device_mesh``.

    Returns:
        The sharded ``jax.Array``.
    """
    import jax
    from jax import numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    devices = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(devices.T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def generate_aranged(x):
        x_start = replace_none_or_zero(x[0].start)
        y_start = replace_none_or_zero(x[1].start)
        # x_start ranges over the 'z' axis (pdims[1] entries), so the stride
        # for y_start must be pdims[1]; mesh position (z, y) is
        # devices[y, z], whose row-major index is y * pdims[1] + z.
        a = jnp.array([[x_start + y_start * pdims[1]]])
        print(f"index is {x} and value is {a}")
        return a

    aranged = jax.make_array_from_callback(mesh.devices.shape,
                                           sharding,
                                           data_callback=generate_aranged)

    return aranged
|
|
|
|
|
|
def create_ones_spmd_array(global_shape, pdims):
    """Create a 3-D array of ones sharded over a 2-D device mesh.

    Args:
        global_shape: 3-element shape of the full array.
        pdims: 2-element device-grid shape; its product must equal
            ``jax.device_count()``.

    Returns:
        ``(global_array, mesh)``. The mesh is the transposed device grid, so
        array axis 0 is split over ``'z'`` (``pdims[1]`` devices) and axis 1
        over ``'y'`` (``pdims[0]`` devices); axis 2 is not split.
    """
    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    n_devices = jax.device_count()
    assert len(global_shape) == 3
    assert len(pdims) == 2
    assert prod(pdims) == n_devices, (
        "The product of pdims must be equal to the number of MPI processes")

    # Per-shard shape: axes 0 and 1 floor-divided by their mesh sizes.
    block_shape = (
        global_shape[0] // pdims[1],
        global_shape[1] // pdims[0],
        global_shape[2],
    )

    device_grid = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(device_grid.T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def ones_block(_index):
        # Every shard is identical, so the index is ignored.
        return jax.numpy.ones(block_shape)

    global_array = jax.make_array_from_callback(global_shape,
                                                sharding,
                                                data_callback=ones_block)
    return global_array, mesh
|
|
|
|
|
|
# Helper that builds a random 3-D array sharded across a 2-D device mesh
|
|
def create_spmd_array(global_shape, pdims):
    """Create a normally-distributed 3-D array sharded over a 2-D device mesh.

    Each shard is filled with ``jax.random.normal`` using a key seeded by
    ``process_slices`` applied to that shard's index slices.

    Args:
        global_shape: 3-element shape of the full array.
        pdims: 2-element device-grid shape; its product must equal
            ``jax.device_count()``.

    Returns:
        ``(global_array, mesh)``. The mesh is the transposed device grid, so
        array axis 0 is split over ``'z'`` (``pdims[1]`` devices) and axis 1
        over ``'y'`` (``pdims[0]`` devices); axis 2 is not split.
    """
    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    n_devices = jax.device_count()
    assert len(global_shape) == 3
    assert len(pdims) == 2
    assert prod(pdims) == n_devices, (
        "The product of pdims must be equal to the number of MPI processes")

    # Per-shard shape: axes 0 and 1 floor-divided by their mesh sizes.
    rows_per_shard = global_shape[0] // pdims[1]
    cols_per_shard = global_shape[1] // pdims[0]
    block_shape = (rows_per_shard, cols_per_shard, global_shape[2])

    # Map each local block onto the global array via a transposed mesh.
    device_grid = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(device_grid.T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def random_block(index):
        key = jax.random.PRNGKey(process_slices(index))
        return jax.random.normal(key, block_shape)

    global_array = jax.make_array_from_callback(global_shape,
                                                sharding,
                                                data_callback=random_block)
    return global_array, mesh
|