update demo script

This commit is contained in:
Wassim KABALAN 2024-10-22 12:10:18 -04:00
parent 105568e8db
commit 4d944f01f2
2 changed files with 126 additions and 124 deletions

View file

@ -11,11 +11,12 @@ size = jax.device_count()
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
SaveAt, diffeqsolve)
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.multihost_utils import process_allgather
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
@ -26,8 +27,10 @@ box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 32
pdims = (1, 1)
mesh = None
sharding = None
if jax.device_count() > 1:
pdims = (8, 1)
pdims = (2, 4)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
@ -40,19 +43,27 @@ def run_simulation(omega_c, sigma8):
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Create initial conditions
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
sharding=sharding,
seed=jax.random.PRNGKey(0))
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
dx, p, _ = lpt(cosmo,
initial_conditions,
0.1,
halo_size=halo_size,
sharding=sharding)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size,
sharding=sharding), None, None
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = LeapfrogMidpoint()
@ -70,21 +81,16 @@ def run_simulation(omega_c, sigma8):
# Return the simulation volume at requested
states = res.ys
field = cic_paint_dx(dx, halo_size=halo_size)
field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
final_fields = [
cic_paint_dx(state[0], halo_size=halo_size) for state in states
cic_paint_dx(state[0], halo_size=halo_size, sharding=sharding)
for state in states
]
return initial_conditions, field, final_fields, res.stats
# Run the simulation
print(f"mesh {mesh}")
if jax.device_count() > 1:
with mesh:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
else:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
# # Print the statistics
@ -101,17 +107,14 @@ if is_on_cluster():
np.save(f'final_field_{i}_{rank}.npy',
final_field.addressable_data(0))
else:
indices = np.arange(len(init.addressable_shards)).reshape(
pdims[::-1]).transpose().flatten()
print(f"indices {indices}")
for i in np.arange(len(init.addressable_shards)):
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i))
np.save(f'field_{i}.npy', field.addressable_data(i))
gathered_init = process_allgather(init, tiled=True)
gathered_field = process_allgather(field, tiled=True)
np.save(f'initial_conditions.npy', gathered_init)
np.save(f'field.npy', gathered_field)
if final_fields is not None:
for j, final_field in enumerate(final_fields):
np.save(f'final_field_{j}_{i}.npy',
final_field.addressable_data(i))
for i, final_field in enumerate(final_fields):
gathered_final_field = process_allgather(final_field, tiled=True)
np.save(f'final_field_{i}.npy', gathered_final_field)
print(f"Finished!!")

View file

@ -1,8 +1,6 @@
import os
from math import prod
setup_done = False
on_cluster = False
@ -28,12 +26,11 @@ def initialize_distributed():
setup_done = True
on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax
def compare_sharding(sharding1, sharding2):
from jaxdecomp._src.spmd_ops import get_pdims_from_sharding
pdims1 = get_pdims_from_sharding(sharding1)
@ -80,8 +77,9 @@ def device_arange(pdims):
print(f"index is {x} and value is {a}")
return a
aranged = jax.make_array_from_callback(
mesh.devices.shape, sharding, data_callback=generate_aranged)
aranged = jax.make_array_from_callback(mesh.devices.shape,
sharding,
data_callback=generate_aranged)
return aranged
@ -96,7 +94,8 @@ def create_ones_spmd_array(global_shape, pdims):
size = jax.device_count()
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
assert (
prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
@ -125,7 +124,8 @@ def create_spmd_array(global_shape, pdims):
size = jax.device_count()
assert (len(global_shape) == 3)
assert (len(pdims) == 2)
assert (prod(pdims) == size
assert (
prod(pdims) == size
), "The product of pdims must be equal to the number of MPI processes"
local_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
@ -142,4 +142,3 @@ def create_spmd_array(global_shape, pdims):
jax.random.PRNGKey(process_slices(x)), local_shape))
return global_array, mesh