From 4d944f01f270bfcc364429e43f4012013b3790b7 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 22 Oct 2024 12:10:18 -0400 Subject: [PATCH] update demo script --- scripts/distributed_pm.py | 55 +++++----- scripts/distributed_utils.py | 195 +++++++++++++++++------------------ 2 files changed, 126 insertions(+), 124 deletions(-) diff --git a/scripts/distributed_pm.py b/scripts/distributed_pm.py index 5411930..29d0b19 100644 --- a/scripts/distributed_pm.py +++ b/scripts/distributed_pm.py @@ -11,11 +11,12 @@ size = jax.device_count() import jax.numpy as jnp import jax_cosmo as jc import numpy as np -from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve +from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, + SaveAt, diffeqsolve) from jax.experimental import mesh_utils from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P - +from jax.experimental.multihost_utils import process_allgather from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx from jaxpm.pm import linear_field, lpt, make_ode_fn @@ -26,8 +27,10 @@ box_size = [float(size)] * 3 snapshots = jnp.linspace(0.1, 1., 4) halo_size = 32 pdims = (1, 1) +mesh = None +sharding = None if jax.device_count() > 1: - pdims = (8, 1) + pdims = (2, 4) devices = mesh_utils.create_device_mesh(pdims) mesh = Mesh(devices.T, axis_names=('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) @@ -40,19 +43,27 @@ def run_simulation(omega_c, sigma8): pk = jc.power.linear_matter_power( jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) - pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) + pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) # Create initial conditions initial_conditions = linear_field(mesh_shape, box_size, pk_fn, + sharding=sharding, seed=jax.random.PRNGKey(0)) cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) # Initial displacement - dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) + dx, p, _ = lpt(cosmo, + initial_conditions, + 0.1, + halo_size=halo_size, + sharding=sharding) + return initial_conditions, cic_paint_dx(dx, + halo_size=halo_size, + sharding=sharding), None, None # Evolve the simulation forward - ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) + ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding) term = ODETerm( lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) solver = LeapfrogMidpoint() @@ -70,22 +81,17 @@ def run_simulation(omega_c, sigma8): # Return the simulation volume at requested states = res.ys - field = cic_paint_dx(dx, halo_size=halo_size) + field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding) final_fields = [ - cic_paint_dx(state[0], halo_size=halo_size) for state in states + cic_paint_dx(state[0], halo_size=halo_size, sharding=sharding) + for state in states ] return initial_conditions, field, final_fields, res.stats # Run the simulation -print(f"mesh {mesh}") -if jax.device_count() > 1: - with mesh: - init, field, final_fields, stats = run_simulation(0.32, 0.8) - -else: - init, field, final_fields, stats = run_simulation(0.32, 0.8) +init, field, final_fields, stats = run_simulation(0.32, 0.8) # # Print the statistics print(stats) @@ -101,17 +107,14 @@ if is_on_cluster(): np.save(f'final_field_{i}_{rank}.npy', final_field.addressable_data(0)) else: - indices = np.arange(len(init.addressable_shards)).reshape( - pdims[::-1]).transpose().flatten() - print(f"indices {indices}") - for i in np.arange(len(init.addressable_shards)): + gathered_init = process_allgather(init, tiled=True) + gathered_field = process_allgather(field, tiled=True) + np.save(f'initial_conditions.npy', gathered_init) + np.save(f'field.npy', gathered_field) - np.save(f'initial_conditions_{i}.npy', init.addressable_data(i)) - np.save(f'field_{i}.npy', field.addressable_data(i)) - - if final_fields is not None: - for j, final_field in enumerate(final_fields): - np.save(f'final_field_{j}_{i}.npy', - final_field.addressable_data(i)) + if final_fields is not None: + for i, final_field in enumerate(final_fields): + gathered_final_field = process_allgather(final_field, tiled=True) + np.save(f'final_field_{i}.npy', gathered_final_field) print(f"Finished!!") diff --git a/scripts/distributed_utils.py b/scripts/distributed_utils.py index 1408d93..63aa64b 100644 --- a/scripts/distributed_utils.py +++ b/scripts/distributed_utils.py @@ -1,145 +1,144 @@ import os from math import prod - - setup_done = False on_cluster = False def is_on_cluster(): - global on_cluster - return on_cluster + global on_cluster + return on_cluster def initialize_distributed(): - global setup_done - global on_cluster - if not setup_done: - if "SLURM_JOB_ID" in os.environ: - on_cluster = True - print("Running on cluster") - import jax - jax.distributed.initialize() - setup_done = True - on_cluster = True - else: - print("Running locally") - setup_done = True - on_cluster = False - os.environ["JAX_PLATFORM_NAME"] = "cpu" - os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" - import jax - - + global setup_done + global on_cluster + if not setup_done: + if "SLURM_JOB_ID" in os.environ: + on_cluster = True + print("Running on cluster") + import jax + jax.distributed.initialize() + setup_done = True + on_cluster = True + else: + print("Running locally") + setup_done = True + on_cluster = False + os.environ["JAX_PLATFORM_NAME"] = "cpu" + os.environ[ + "XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" + import jax def compare_sharding(sharding1, sharding2): - from jaxdecomp._src.spmd_ops import get_pdims_from_sharding - pdims1 = get_pdims_from_sharding(sharding1) - pdims2 = get_pdims_from_sharding(sharding2) - pdims1 = pdims1 + (1,) * (3 - len(pdims1)) - pdims2 = pdims2 + (1,) * (3 - len(pdims2)) - return pdims1 == pdims2 + from jaxdecomp._src.spmd_ops import get_pdims_from_sharding + pdims1 = get_pdims_from_sharding(sharding1) + pdims2 = get_pdims_from_sharding(sharding2) + pdims1 = pdims1 + (1, ) * (3 - len(pdims1)) + pdims2 = pdims2 + (1, ) * (3 - len(pdims2)) + return pdims1 == pdims2 def replace_none_or_zero(value): - # Replace None or 0 with 1 - return 0 if value is None else value + # Replace None or 0 with 1 + return 0 if value is None else value def process_slices(slices_tuple): - start_product = 1 - stop_product = 1 + start_product = 1 + stop_product = 1 - for s in slices_tuple: - # Multiply the start and stop values, replacing None/0 with 1 - start_product *= replace_none_or_zero(s.start) - stop_product *= replace_none_or_zero(s.stop) + for s in slices_tuple: + # Multiply the start and stop values, replacing None/0 with 1 + start_product *= replace_none_or_zero(s.start) + stop_product *= replace_none_or_zero(s.stop) - # Return the sum of the two products - return int(start_product + stop_product) + # Return the sum of the two products + return int(start_product + stop_product) def device_arange(pdims): - import jax - from jax import numpy as jnp - from jax.experimental import mesh_utils - from jax.sharding import Mesh, NamedSharding - from jax.sharding import PartitionSpec as P + import jax + from jax import numpy as jnp + from jax.experimental import mesh_utils + from jax.sharding import Mesh, NamedSharding + from jax.sharding import PartitionSpec as P - devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices.T, axis_names=('z', 'y')) - sharding = NamedSharding(mesh, P('z', 'y')) + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices.T, axis_names=('z', 'y')) + sharding = NamedSharding(mesh, P('z', 'y')) - def generate_aranged(x): - x_start = replace_none_or_zero(x[0].start) - y_start = replace_none_or_zero(x[1].start) - a = jnp.array([[x_start + y_start * pdims[0]]]) - print(f"index is {x} and value is {a}") - return a + def generate_aranged(x): + x_start = replace_none_or_zero(x[0].start) + y_start = replace_none_or_zero(x[1].start) + a = jnp.array([[x_start + y_start * pdims[0]]]) + print(f"index is {x} and value is {a}") + return a - aranged = jax.make_array_from_callback( - mesh.devices.shape, sharding, data_callback=generate_aranged) + aranged = jax.make_array_from_callback(mesh.devices.shape, + sharding, + data_callback=generate_aranged) - return aranged + return aranged def create_ones_spmd_array(global_shape, pdims): - import jax - from jax.experimental import mesh_utils - from jax.sharding import Mesh, NamedSharding - from jax.sharding import PartitionSpec as P + import jax + from jax.experimental import mesh_utils + from jax.sharding import Mesh, NamedSharding + from jax.sharding import PartitionSpec as P - size = jax.device_count() - assert (len(global_shape) == 3) - assert (len(pdims) == 2) - assert (prod(pdims) == size - ), "The product of pdims must be equal to the number of MPI processes" + size = jax.device_count() + assert (len(global_shape) == 3) + assert (len(pdims) == 2) + assert ( + prod(pdims) == size + ), "The product of pdims must be equal to the number of MPI processes" - local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], - global_shape[2]) + local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], + global_shape[2]) - # Remap to the global array from the local slice - devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices.T, axis_names=('z', 'y')) - sharding = NamedSharding(mesh, P('z', 'y')) - global_array = jax.make_array_from_callback( - global_shape, - sharding, - data_callback=lambda _: jax.numpy.ones(local_shape)) + # Remap to the global array from the local slice + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices.T, axis_names=('z', 'y')) + sharding = NamedSharding(mesh, P('z', 'y')) + global_array = jax.make_array_from_callback( + global_shape, + sharding, + data_callback=lambda _: jax.numpy.ones(local_shape)) - return global_array, mesh + return global_array, mesh # Helper function to create a 3D array and remap it to the global array def create_spmd_array(global_shape, pdims): - import jax - from jax.experimental import mesh_utils - from jax.sharding import Mesh, NamedSharding - from jax.sharding import PartitionSpec as P + import jax + from jax.experimental import mesh_utils + from jax.sharding import Mesh, NamedSharding + from jax.sharding import PartitionSpec as P - size = jax.device_count() - assert (len(global_shape) == 3) - assert (len(pdims) == 2) - assert (prod(pdims) == size - ), "The product of pdims must be equal to the number of MPI processes" + size = jax.device_count() + assert (len(global_shape) == 3) + assert (len(pdims) == 2) + assert ( + prod(pdims) == size + ), "The product of pdims must be equal to the number of MPI processes" - local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], - global_shape[2]) + local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], + global_shape[2]) - # Remap to the global array from the local slicei - devices = mesh_utils.create_device_mesh(pdims) - mesh = Mesh(devices.T, axis_names=('z', 'y')) - sharding = NamedSharding(mesh, P('z', 'y')) - global_array = jax.make_array_from_callback( - global_shape, - sharding, - data_callback=lambda x: jax.random.normal( - jax.random.PRNGKey(process_slices(x)), local_shape)) - - return global_array, mesh + # Remap to the global array from the local slicei + devices = mesh_utils.create_device_mesh(pdims) + mesh = Mesh(devices.T, axis_names=('z', 'y')) + sharding = NamedSharding(mesh, P('z', 'y')) + global_array = jax.make_array_from_callback( + global_shape, + sharding, + data_callback=lambda x: jax.random.normal( + jax.random.PRNGKey(process_slices(x)), local_shape)) + return global_array, mesh