mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 12:31:11 +00:00
make distributed pm work in single controller
This commit is contained in:
parent
9c94f994ff
commit
375f2048e4
2 changed files with 176 additions and 13 deletions
|
@ -1,12 +1,12 @@
|
|||
import os
|
||||
|
||||
from distributed_utils import initialize_distributed, is_on_cluster
|
||||
|
||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||
initialize_distributed()
|
||||
import jax
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
rank = jax.process_index()
|
||||
size = jax.process_count()
|
||||
size = jax.device_count()
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
|
@ -24,9 +24,9 @@ size = 256
|
|||
mesh_shape = [size] * 3
|
||||
box_size = [float(size)] * 3
|
||||
snapshots = jnp.linspace(0.1, 1., 4)
|
||||
halo_size = 64
|
||||
halo_size = 32
|
||||
pdims = (1, 1)
|
||||
if jax.device_count() > 1:
|
||||
|
||||
pdims = (4, 2)
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
|
@ -51,7 +51,8 @@ def run_simulation(omega_c, sigma8):
|
|||
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
||||
|
||||
return initial_conditions, cic_paint_dx(dx,
|
||||
halo_size=halo_size), None, None
|
||||
# Evolve the simulation forward
|
||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
||||
term = ODETerm(
|
||||
|
@ -80,6 +81,7 @@ def run_simulation(omega_c, sigma8):
|
|||
|
||||
|
||||
# Run the simulation
|
||||
print(f"mesh {mesh}")
|
||||
if jax.device_count() > 1:
|
||||
with mesh:
|
||||
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
||||
|
@ -89,13 +91,29 @@ else:
|
|||
|
||||
# # Print the statistics
|
||||
print(stats)
|
||||
print(f"done now saving")
|
||||
if is_on_cluster():
|
||||
rank = jax.process_index()
|
||||
# # save the final state
|
||||
np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0))
|
||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||
|
||||
# # save the final state
|
||||
np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0))
|
||||
np.save(f'field_{rank}.npy', field.addressable_data(0))
|
||||
if final_fields is not None:
|
||||
for i, final_field in enumerate(final_fields):
|
||||
np.save(f'final_field_{i}_{rank}.npy',
|
||||
final_field.addressable_data(0))
|
||||
else:
|
||||
indices = np.arange(len(init.addressable_shards)).reshape(
|
||||
pdims[::-1]).transpose().flatten()
|
||||
print(f"indices {indices}")
|
||||
for i in np.arange(len(init.addressable_shards)):
|
||||
|
||||
if final_fields is not None:
|
||||
for i, final_field in enumerate(final_fields):
|
||||
np.save(f'final_field_{i}_{rank}.npy', final_field.addressable_data(0))
|
||||
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i))
|
||||
np.save(f'field_{i}.npy', field.addressable_data(i))
|
||||
|
||||
if final_fields is not None:
|
||||
for j, final_field in enumerate(final_fields):
|
||||
np.save(f'final_field_{j}_{i}.npy',
|
||||
final_field.addressable_data(i))
|
||||
|
||||
print(f"Finished!!")
|
||||
|
|
145
scripts/distributed_utils.py
Normal file
145
scripts/distributed_utils.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import os
|
||||
from math import prod
|
||||
|
||||
|
||||
|
||||
setup_done = False
|
||||
on_cluster = False
|
||||
|
||||
|
||||
def is_on_cluster():
|
||||
global on_cluster
|
||||
return on_cluster
|
||||
|
||||
|
||||
def initialize_distributed():
|
||||
global setup_done
|
||||
global on_cluster
|
||||
if not setup_done:
|
||||
if "SLURM_JOB_ID" in os.environ:
|
||||
on_cluster = True
|
||||
print("Running on cluster")
|
||||
import jax
|
||||
jax.distributed.initialize()
|
||||
setup_done = True
|
||||
on_cluster = True
|
||||
else:
|
||||
print("Running locally")
|
||||
setup_done = True
|
||||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
import jax
|
||||
|
||||
|
||||
|
||||
|
||||
def compare_sharding(sharding1, sharding2):
|
||||
from jaxdecomp._src.spmd_ops import get_pdims_from_sharding
|
||||
pdims1 = get_pdims_from_sharding(sharding1)
|
||||
pdims2 = get_pdims_from_sharding(sharding2)
|
||||
pdims1 = pdims1 + (1,) * (3 - len(pdims1))
|
||||
pdims2 = pdims2 + (1,) * (3 - len(pdims2))
|
||||
return pdims1 == pdims2
|
||||
|
||||
|
||||
def replace_none_or_zero(value):
|
||||
# Replace None or 0 with 1
|
||||
return 0 if value is None else value
|
||||
|
||||
|
||||
def process_slices(slices_tuple):
|
||||
|
||||
start_product = 1
|
||||
stop_product = 1
|
||||
|
||||
for s in slices_tuple:
|
||||
# Multiply the start and stop values, replacing None/0 with 1
|
||||
start_product *= replace_none_or_zero(s.start)
|
||||
stop_product *= replace_none_or_zero(s.stop)
|
||||
|
||||
# Return the sum of the two products
|
||||
return int(start_product + stop_product)
|
||||
|
||||
|
||||
def device_arange(pdims):
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('z', 'y'))
|
||||
sharding = NamedSharding(mesh, P('z', 'y'))
|
||||
|
||||
def generate_aranged(x):
|
||||
x_start = replace_none_or_zero(x[0].start)
|
||||
y_start = replace_none_or_zero(x[1].start)
|
||||
a = jnp.array([[x_start + y_start * pdims[0]]])
|
||||
print(f"index is {x} and value is {a}")
|
||||
return a
|
||||
|
||||
aranged = jax.make_array_from_callback(
|
||||
mesh.devices.shape, sharding, data_callback=generate_aranged)
|
||||
|
||||
return aranged
|
||||
|
||||
|
||||
def create_ones_spmd_array(global_shape, pdims):
|
||||
|
||||
import jax
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
size = jax.device_count()
|
||||
assert (len(global_shape) == 3)
|
||||
assert (len(pdims) == 2)
|
||||
assert (prod(pdims) == size
|
||||
), "The product of pdims must be equal to the number of MPI processes"
|
||||
|
||||
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
|
||||
global_shape[2])
|
||||
|
||||
# Remap to the global array from the local slice
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('z', 'y'))
|
||||
sharding = NamedSharding(mesh, P('z', 'y'))
|
||||
global_array = jax.make_array_from_callback(
|
||||
global_shape,
|
||||
sharding,
|
||||
data_callback=lambda _: jax.numpy.ones(local_shape))
|
||||
|
||||
return global_array, mesh
|
||||
|
||||
|
||||
# Helper function to create a 3D array and remap it to the global array
|
||||
def create_spmd_array(global_shape, pdims):
|
||||
|
||||
import jax
|
||||
from jax.experimental import mesh_utils
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
size = jax.device_count()
|
||||
assert (len(global_shape) == 3)
|
||||
assert (len(pdims) == 2)
|
||||
assert (prod(pdims) == size
|
||||
), "The product of pdims must be equal to the number of MPI processes"
|
||||
|
||||
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
|
||||
global_shape[2])
|
||||
|
||||
# Remap to the global array from the local slicei
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('z', 'y'))
|
||||
sharding = NamedSharding(mesh, P('z', 'y'))
|
||||
global_array = jax.make_array_from_callback(
|
||||
global_shape,
|
||||
sharding,
|
||||
data_callback=lambda x: jax.random.normal(
|
||||
jax.random.PRNGKey(process_slices(x)), local_shape))
|
||||
|
||||
return global_array, mesh
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue