update demo script

This commit is contained in:
Wassim KABALAN 2024-10-22 12:10:18 -04:00
parent 105568e8db
commit 4d944f01f2
2 changed files with 126 additions and 124 deletions

View file

@ -11,11 +11,12 @@ size = jax.device_count()
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
SaveAt, diffeqsolve)
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather
from jaxpm.kernels import interpolate_power_spectrum from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
@ -26,8 +27,10 @@ box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4) snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32 halo_size = 32
pdims = (1, 1) pdims = (1, 1)
mesh = None
sharding = None
if jax.device_count() > 1: if jax.device_count() > 1:
pdims = (8, 1) pdims = (2, 4)
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y')) mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
@ -40,19 +43,27 @@ def run_simulation(omega_c, sigma8):
pk = jc.power.linear_matter_power( pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Create initial conditions # Create initial conditions
initial_conditions = linear_field(mesh_shape, initial_conditions = linear_field(mesh_shape,
box_size, box_size,
pk_fn, pk_fn,
sharding=sharding,
seed=jax.random.PRNGKey(0)) seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) dx, p, _ = lpt(cosmo,
initial_conditions,
0.1,
halo_size=halo_size,
sharding=sharding)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size,
sharding=sharding), None, None
# Evolve the simulation forward # Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm( term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = LeapfrogMidpoint() solver = LeapfrogMidpoint()
@ -70,22 +81,17 @@ def run_simulation(omega_c, sigma8):
# Return the simulation volume at requested # Return the simulation volume at requested
states = res.ys states = res.ys
field = cic_paint_dx(dx, halo_size=halo_size) field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
final_fields = [ final_fields = [
cic_paint_dx(state[0], halo_size=halo_size) for state in states cic_paint_dx(state[0], halo_size=halo_size, sharding=sharding)
for state in states
] ]
return initial_conditions, field, final_fields, res.stats return initial_conditions, field, final_fields, res.stats
# Run the simulation # Run the simulation
print(f"mesh {mesh}") init, field, final_fields, stats = run_simulation(0.32, 0.8)
if jax.device_count() > 1:
with mesh:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
else:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
# # Print the statistics # # Print the statistics
print(stats) print(stats)
@ -101,17 +107,14 @@ if is_on_cluster():
np.save(f'final_field_{i}_{rank}.npy', np.save(f'final_field_{i}_{rank}.npy',
final_field.addressable_data(0)) final_field.addressable_data(0))
else: else:
indices = np.arange(len(init.addressable_shards)).reshape( gathered_init = process_allgather(init, tiled=True)
pdims[::-1]).transpose().flatten() gathered_field = process_allgather(field, tiled=True)
print(f"indices {indices}") np.save(f'initial_conditions.npy', gathered_init)
for i in np.arange(len(init.addressable_shards)): np.save(f'field.npy', gathered_field)
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i)) if final_fields is not None:
np.save(f'field_{i}.npy', field.addressable_data(i)) for i, final_field in enumerate(final_fields):
gathered_final_field = process_allgather(final_field, tiled=True)
if final_fields is not None: np.save(f'final_field_{i}.npy', gathered_final_field)
for j, final_field in enumerate(final_fields):
np.save(f'final_field_{j}_{i}.npy',
final_field.addressable_data(i))
print(f"Finished!!") print(f"Finished!!")

View file

@ -1,145 +1,144 @@
import os import os
from math import prod from math import prod
setup_done = False setup_done = False
on_cluster = False on_cluster = False
def is_on_cluster(): def is_on_cluster():
global on_cluster global on_cluster
return on_cluster return on_cluster
def initialize_distributed(): def initialize_distributed():
global setup_done global setup_done
global on_cluster global on_cluster
if not setup_done: if not setup_done:
if "SLURM_JOB_ID" in os.environ: if "SLURM_JOB_ID" in os.environ:
on_cluster = True on_cluster = True
print("Running on cluster") print("Running on cluster")
import jax import jax
jax.distributed.initialize() jax.distributed.initialize()
setup_done = True setup_done = True
on_cluster = True on_cluster = True
else: else:
print("Running locally") print("Running locally")
setup_done = True setup_done = True
on_cluster = False on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" os.environ[
import jax "XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
def compare_sharding(sharding1, sharding2): def compare_sharding(sharding1, sharding2):
from jaxdecomp._src.spmd_ops import get_pdims_from_sharding from jaxdecomp._src.spmd_ops import get_pdims_from_sharding
pdims1 = get_pdims_from_sharding(sharding1) pdims1 = get_pdims_from_sharding(sharding1)
pdims2 = get_pdims_from_sharding(sharding2) pdims2 = get_pdims_from_sharding(sharding2)
pdims1 = pdims1 + (1,) * (3 - len(pdims1)) pdims1 = pdims1 + (1, ) * (3 - len(pdims1))
pdims2 = pdims2 + (1,) * (3 - len(pdims2)) pdims2 = pdims2 + (1, ) * (3 - len(pdims2))
return pdims1 == pdims2 return pdims1 == pdims2
def replace_none_or_zero(value): def replace_none_or_zero(value):
# Replace None or 0 with 1 # Replace None or 0 with 1
return 0 if value is None else value return 0 if value is None else value
def process_slices(slices_tuple): def process_slices(slices_tuple):
start_product = 1 start_product = 1
stop_product = 1 stop_product = 1
for s in slices_tuple: for s in slices_tuple:
# Multiply the start and stop values, replacing None/0 with 1 # Multiply the start and stop values, replacing None/0 with 1
start_product *= replace_none_or_zero(s.start) start_product *= replace_none_or_zero(s.start)
stop_product *= replace_none_or_zero(s.stop) stop_product *= replace_none_or_zero(s.stop)
# Return the sum of the two products # Return the sum of the two products
return int(start_product + stop_product) return int(start_product + stop_product)
def device_arange(pdims): def device_arange(pdims):
import jax import jax
from jax import numpy as jnp from jax import numpy as jnp
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('z', 'y')) mesh = Mesh(devices.T, axis_names=('z', 'y'))
sharding = NamedSharding(mesh, P('z', 'y')) sharding = NamedSharding(mesh, P('z', 'y'))
def generate_aranged(x): def generate_aranged(x):
x_start = replace_none_or_zero(x[0].start) x_start = replace_none_or_zero(x[0].start)
y_start = replace_none_or_zero(x[1].start) y_start = replace_none_or_zero(x[1].start)
a = jnp.array([[x_start + y_start * pdims[0]]]) a = jnp.array([[x_start + y_start * pdims[0]]])
print(f"index is {x} and value is {a}") print(f"index is {x} and value is {a}")
return a return a
aranged = jax.make_array_from_callback( aranged = jax.make_array_from_callback(mesh.devices.shape,
mesh.devices.shape, sharding, data_callback=generate_aranged) sharding,
data_callback=generate_aranged)
return aranged return aranged
def create_ones_spmd_array(global_shape, pdims): def create_ones_spmd_array(global_shape, pdims):
import jax import jax
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
size = jax.device_count() size = jax.device_count()
assert (len(global_shape) == 3) assert (len(global_shape) == 3)
assert (len(pdims) == 2) assert (len(pdims) == 2)
assert (prod(pdims) == size assert (
), "The product of pdims must be equal to the number of MPI processes" prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]) global_shape[2])
# Remap to the global array from the local slice # Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('z', 'y')) mesh = Mesh(devices.T, axis_names=('z', 'y'))
sharding = NamedSharding(mesh, P('z', 'y')) sharding = NamedSharding(mesh, P('z', 'y'))
global_array = jax.make_array_from_callback( global_array = jax.make_array_from_callback(
global_shape, global_shape,
sharding, sharding,
data_callback=lambda _: jax.numpy.ones(local_shape)) data_callback=lambda _: jax.numpy.ones(local_shape))
return global_array, mesh return global_array, mesh
# Helper function to create a 3D array and remap it to the global array # Helper function to create a 3D array and remap it to the global array
def create_spmd_array(global_shape, pdims): def create_spmd_array(global_shape, pdims):
import jax import jax
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
size = jax.device_count() size = jax.device_count()
assert (len(global_shape) == 3) assert (len(global_shape) == 3)
assert (len(pdims) == 2) assert (len(pdims) == 2)
assert (prod(pdims) == size assert (
), "The product of pdims must be equal to the number of MPI processes" prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
global_shape[2]) global_shape[2])
# Remap to the global array from the local slicei # Remap to the global array from the local slicei
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('z', 'y')) mesh = Mesh(devices.T, axis_names=('z', 'y'))
sharding = NamedSharding(mesh, P('z', 'y')) sharding = NamedSharding(mesh, P('z', 'y'))
global_array = jax.make_array_from_callback( global_array = jax.make_array_from_callback(
global_shape, global_shape,
sharding, sharding,
data_callback=lambda x: jax.random.normal( data_callback=lambda x: jax.random.normal(
jax.random.PRNGKey(process_slices(x)), local_shape)) jax.random.PRNGKey(process_slices(x)), local_shape))
return global_array, mesh
return global_array, mesh