update demo script

This commit is contained in:
Wassim KABALAN 2024-10-22 12:10:18 -04:00
parent 105568e8db
commit 4d944f01f2
2 changed files with 126 additions and 124 deletions

View file

@ -11,11 +11,12 @@ size = jax.device_count()
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
import numpy as np import numpy as np
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
SaveAt, diffeqsolve)
from jax.experimental import mesh_utils from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather
from jaxpm.kernels import interpolate_power_spectrum from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
@ -26,8 +27,10 @@ box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4) snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32 halo_size = 32
pdims = (1, 1) pdims = (1, 1)
mesh = None
sharding = None
if jax.device_count() > 1: if jax.device_count() > 1:
pdims = (8, 1) pdims = (2, 4)
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y')) mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
@ -40,19 +43,27 @@ def run_simulation(omega_c, sigma8):
pk = jc.power.linear_matter_power( pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk) pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Create initial conditions # Create initial conditions
initial_conditions = linear_field(mesh_shape, initial_conditions = linear_field(mesh_shape,
box_size, box_size,
pk_fn, pk_fn,
sharding=sharding,
seed=jax.random.PRNGKey(0)) seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size) dx, p, _ = lpt(cosmo,
initial_conditions,
0.1,
halo_size=halo_size,
sharding=sharding)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size,
sharding=sharding), None, None
# Evolve the simulation forward # Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size) ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm( term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = LeapfrogMidpoint() solver = LeapfrogMidpoint()
@ -70,21 +81,16 @@ def run_simulation(omega_c, sigma8):
# Return the simulation volume at requested # Return the simulation volume at requested
states = res.ys states = res.ys
field = cic_paint_dx(dx, halo_size=halo_size) field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
final_fields = [ final_fields = [
cic_paint_dx(state[0], halo_size=halo_size) for state in states cic_paint_dx(state[0], halo_size=halo_size, sharding=sharding)
for state in states
] ]
return initial_conditions, field, final_fields, res.stats return initial_conditions, field, final_fields, res.stats
# Run the simulation # Run the simulation
print(f"mesh {mesh}")
if jax.device_count() > 1:
with mesh:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
else:
init, field, final_fields, stats = run_simulation(0.32, 0.8) init, field, final_fields, stats = run_simulation(0.32, 0.8)
# # Print the statistics # # Print the statistics
@ -101,17 +107,14 @@ if is_on_cluster():
np.save(f'final_field_{i}_{rank}.npy', np.save(f'final_field_{i}_{rank}.npy',
final_field.addressable_data(0)) final_field.addressable_data(0))
else: else:
indices = np.arange(len(init.addressable_shards)).reshape( gathered_init = process_allgather(init, tiled=True)
pdims[::-1]).transpose().flatten() gathered_field = process_allgather(field, tiled=True)
print(f"indices {indices}") np.save(f'initial_conditions.npy', gathered_init)
for i in np.arange(len(init.addressable_shards)): np.save(f'field.npy', gathered_field)
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i))
np.save(f'field_{i}.npy', field.addressable_data(i))
if final_fields is not None: if final_fields is not None:
for j, final_field in enumerate(final_fields): for i, final_field in enumerate(final_fields):
np.save(f'final_field_{j}_{i}.npy', gathered_final_field = process_allgather(final_field, tiled=True)
final_field.addressable_data(i)) np.save(f'final_field_{i}.npy', gathered_final_field)
print(f"Finished!!") print(f"Finished!!")

View file

@ -1,8 +1,6 @@
import os import os
from math import prod from math import prod
setup_done = False setup_done = False
on_cluster = False on_cluster = False
@ -28,12 +26,11 @@ def initialize_distributed():
setup_done = True setup_done = True
on_cluster = False on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax import jax
def compare_sharding(sharding1, sharding2): def compare_sharding(sharding1, sharding2):
from jaxdecomp._src.spmd_ops import get_pdims_from_sharding from jaxdecomp._src.spmd_ops import get_pdims_from_sharding
pdims1 = get_pdims_from_sharding(sharding1) pdims1 = get_pdims_from_sharding(sharding1)
@ -80,8 +77,9 @@ def device_arange(pdims):
print(f"index is {x} and value is {a}") print(f"index is {x} and value is {a}")
return a return a
aranged = jax.make_array_from_callback( aranged = jax.make_array_from_callback(mesh.devices.shape,
mesh.devices.shape, sharding, data_callback=generate_aranged) sharding,
data_callback=generate_aranged)
return aranged return aranged
@ -96,7 +94,8 @@ def create_ones_spmd_array(global_shape, pdims):
size = jax.device_count() size = jax.device_count()
assert (len(global_shape) == 3) assert (len(global_shape) == 3)
assert (len(pdims) == 2) assert (len(pdims) == 2)
assert (prod(pdims) == size assert (
prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes" ), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
@ -125,7 +124,8 @@ def create_spmd_array(global_shape, pdims):
size = jax.device_count() size = jax.device_count()
assert (len(global_shape) == 3) assert (len(global_shape) == 3)
assert (len(pdims) == 2) assert (len(pdims) == 2)
assert (prod(pdims) == size assert (
prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes" ), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0], local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
@ -142,4 +142,3 @@ def create_spmd_array(global_shape, pdims):
jax.random.PRNGKey(process_slices(x)), local_shape)) jax.random.PRNGKey(process_slices(x)), local_shape))
return global_array, mesh return global_array, mesh