make distributed pm work in single controller

This commit is contained in:
Wassim KABALAN 2024-10-21 14:03:48 -04:00
parent 9c94f994ff
commit 375f2048e4
2 changed files with 176 additions and 13 deletions

View file

@ -1,12 +1,12 @@
import os
from distributed_utils import initialize_distributed, is_on_cluster
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
initialize_distributed()
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
size = jax.device_count()
import jax.numpy as jnp
import jax_cosmo as jc
@ -24,9 +24,9 @@ size = 256
mesh_shape = [size] * 3
box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 64
halo_size = 32
pdims = (1, 1)
if jax.device_count() > 1:
pdims = (4, 2)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
@ -51,7 +51,8 @@ def run_simulation(omega_c, sigma8):
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
return initial_conditions, cic_paint_dx(dx,
halo_size=halo_size), None, None
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
term = ODETerm(
@ -80,6 +81,7 @@ def run_simulation(omega_c, sigma8):
# Run the simulation
print(f"mesh {mesh}")
if jax.device_count() > 1:
with mesh:
init, field, final_fields, stats = run_simulation(0.32, 0.8)
@ -89,13 +91,29 @@ else:
# # Print the statistics
print(stats)
print(f"done now saving")
if is_on_cluster():
rank = jax.process_index()
# # save the final state
np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0))
np.save(f'field_{rank}.npy', field.addressable_data(0))
# # save the final state
np.save(f'initial_conditions_{rank}.npy', init.addressable_data(0))
np.save(f'field_{rank}.npy', field.addressable_data(0))
if final_fields is not None:
for i, final_field in enumerate(final_fields):
np.save(f'final_field_{i}_{rank}.npy',
final_field.addressable_data(0))
else:
indices = np.arange(len(init.addressable_shards)).reshape(
pdims[::-1]).transpose().flatten()
print(f"indices {indices}")
for i in np.arange(len(init.addressable_shards)):
if final_fields is not None:
for i, final_field in enumerate(final_fields):
np.save(f'final_field_{i}_{rank}.npy', final_field.addressable_data(0))
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i))
np.save(f'field_{i}.npy', field.addressable_data(i))
if final_fields is not None:
for j, final_field in enumerate(final_fields):
np.save(f'final_field_{j}_{i}.npy',
final_field.addressable_data(i))
print(f"Finished!!")