mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-07-12 06:43:04 +00:00
update demo script
This commit is contained in:
parent
105568e8db
commit
4d944f01f2
2 changed files with 126 additions and 124 deletions
|
@ -11,11 +11,12 @@ size = jax.device_count()
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffrax import Dopri5,LeapfrogMidpoint, ODETerm, ConstantStepSize, SaveAt, diffeqsolve
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
|
SaveAt, diffeqsolve)
|
||||||
from jax.experimental import mesh_utils
|
from jax.experimental import mesh_utils
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
from jaxpm.painting import cic_paint_dx
|
from jaxpm.painting import cic_paint_dx
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
@ -26,8 +27,10 @@ box_size = [float(size)] * 3
|
||||||
snapshots = jnp.linspace(0.1, 1., 4)
|
snapshots = jnp.linspace(0.1, 1., 4)
|
||||||
halo_size = 32
|
halo_size = 32
|
||||||
pdims = (1, 1)
|
pdims = (1, 1)
|
||||||
|
mesh = None
|
||||||
|
sharding = None
|
||||||
if jax.device_count() > 1:
|
if jax.device_count() > 1:
|
||||||
pdims = (8, 1)
|
pdims = (2, 4)
|
||||||
devices = mesh_utils.create_device_mesh(pdims)
|
devices = mesh_utils.create_device_mesh(pdims)
|
||||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
|
@ -40,19 +43,27 @@ def run_simulation(omega_c, sigma8):
|
||||||
pk = jc.power.linear_matter_power(
|
pk = jc.power.linear_matter_power(
|
||||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
|
||||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk)
|
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||||
# Create initial conditions
|
# Create initial conditions
|
||||||
initial_conditions = linear_field(mesh_shape,
|
initial_conditions = linear_field(mesh_shape,
|
||||||
box_size,
|
box_size,
|
||||||
pk_fn,
|
pk_fn,
|
||||||
|
sharding=sharding,
|
||||||
seed=jax.random.PRNGKey(0))
|
seed=jax.random.PRNGKey(0))
|
||||||
|
|
||||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
|
|
||||||
# Initial displacement
|
# Initial displacement
|
||||||
dx, p, _ = lpt(cosmo, initial_conditions, 0.1, halo_size=halo_size)
|
dx, p, _ = lpt(cosmo,
|
||||||
|
initial_conditions,
|
||||||
|
0.1,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
return initial_conditions, cic_paint_dx(dx,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding), None, None
|
||||||
# Evolve the simulation forward
|
# Evolve the simulation forward
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size)
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||||
term = ODETerm(
|
term = ODETerm(
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
solver = LeapfrogMidpoint()
|
solver = LeapfrogMidpoint()
|
||||||
|
@ -70,22 +81,17 @@ def run_simulation(omega_c, sigma8):
|
||||||
|
|
||||||
# Return the simulation volume at requested
|
# Return the simulation volume at requested
|
||||||
states = res.ys
|
states = res.ys
|
||||||
field = cic_paint_dx(dx, halo_size=halo_size)
|
field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
|
||||||
final_fields = [
|
final_fields = [
|
||||||
cic_paint_dx(state[0], halo_size=halo_size) for state in states
|
cic_paint_dx(state[0], halo_size=halo_size, sharding=sharding)
|
||||||
|
for state in states
|
||||||
]
|
]
|
||||||
|
|
||||||
return initial_conditions, field, final_fields, res.stats
|
return initial_conditions, field, final_fields, res.stats
|
||||||
|
|
||||||
|
|
||||||
# Run the simulation
|
# Run the simulation
|
||||||
print(f"mesh {mesh}")
|
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
||||||
if jax.device_count() > 1:
|
|
||||||
with mesh:
|
|
||||||
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
|
||||||
|
|
||||||
else:
|
|
||||||
init, field, final_fields, stats = run_simulation(0.32, 0.8)
|
|
||||||
|
|
||||||
# # Print the statistics
|
# # Print the statistics
|
||||||
print(stats)
|
print(stats)
|
||||||
|
@ -101,17 +107,14 @@ if is_on_cluster():
|
||||||
np.save(f'final_field_{i}_{rank}.npy',
|
np.save(f'final_field_{i}_{rank}.npy',
|
||||||
final_field.addressable_data(0))
|
final_field.addressable_data(0))
|
||||||
else:
|
else:
|
||||||
indices = np.arange(len(init.addressable_shards)).reshape(
|
gathered_init = process_allgather(init, tiled=True)
|
||||||
pdims[::-1]).transpose().flatten()
|
gathered_field = process_allgather(field, tiled=True)
|
||||||
print(f"indices {indices}")
|
np.save(f'initial_conditions.npy', gathered_init)
|
||||||
for i in np.arange(len(init.addressable_shards)):
|
np.save(f'field.npy', gathered_field)
|
||||||
|
|
||||||
np.save(f'initial_conditions_{i}.npy', init.addressable_data(i))
|
if final_fields is not None:
|
||||||
np.save(f'field_{i}.npy', field.addressable_data(i))
|
for i, final_field in enumerate(final_fields):
|
||||||
|
gathered_final_field = process_allgather(final_field, tiled=True)
|
||||||
if final_fields is not None:
|
np.save(f'final_field_{i}.npy', gathered_final_field)
|
||||||
for j, final_field in enumerate(final_fields):
|
|
||||||
np.save(f'final_field_{j}_{i}.npy',
|
|
||||||
final_field.addressable_data(i))
|
|
||||||
|
|
||||||
print(f"Finished!!")
|
print(f"Finished!!")
|
||||||
|
|
|
@ -1,145 +1,144 @@
|
||||||
import os
|
import os
|
||||||
from math import prod
|
from math import prod
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# True once initialize_distributed() has completed its one-time setup.
setup_done = False
# True when initialize_distributed() found SLURM_JOB_ID in the environment;
# read back through is_on_cluster().
on_cluster = False
|
||||||
|
|
||||||
|
|
||||||
def is_on_cluster():
    """Return the module-level ``on_cluster`` flag.

    The flag starts out ``False`` and is set by ``initialize_distributed``:
    ``True`` when ``SLURM_JOB_ID`` is present in the environment, ``False``
    otherwise.
    """
    # Reading a module global needs no ``global`` declaration.
    return on_cluster
|
||||||
|
|
||||||
|
|
||||||
def initialize_distributed():
    """Perform the one-time JAX setup for cluster or local execution.

    The first call branches on the environment:

    * ``SLURM_JOB_ID`` present: marks the process as running on a cluster and
      calls ``jax.distributed.initialize()``.
    * otherwise: marks the process as local, sets ``JAX_PLATFORM_NAME`` to
      ``"cpu"`` and sets ``XLA_FLAGS`` to request 8 host-platform devices,
      then imports ``jax``.

    Subsequent calls return immediately because ``setup_done`` is already
    ``True``.
    """
    global setup_done
    global on_cluster

    if setup_done:
        return

    if "SLURM_JOB_ID" not in os.environ:
        print("Running locally")
        setup_done = True
        on_cluster = False
        # NOTE(review): the environment is configured before ``jax`` is
        # imported, presumably so the flags are seen at JAX start-up - confirm.
        os.environ["JAX_PLATFORM_NAME"] = "cpu"
        os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
        import jax
        return

    # The flag is raised before initialisation, so it stays True even if
    # ``jax.distributed.initialize()`` raises (``setup_done`` then stays False).
    on_cluster = True
    print("Running on cluster")
    import jax
    jax.distributed.initialize()
    setup_done = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compare_sharding(sharding1, sharding2):
    """Return whether two shardings yield the same padded ``pdims`` tuple.

    Each sharding is converted with ``get_pdims_from_sharding``; results
    shorter than three entries are right-padded with ``1`` before the two
    tuples are compared for equality.
    """
    from jaxdecomp._src.spmd_ops import get_pdims_from_sharding

    def _padded_pdims(sharding):
        # A non-positive repeat count yields an empty tuple, so results that
        # already have three or more entries pass through unchanged.
        dims = get_pdims_from_sharding(sharding)
        return dims + (1, ) * (3 - len(dims))

    return _padded_pdims(sharding1) == _padded_pdims(sharding2)
|
||||||
|
|
||||||
|
|
||||||
def replace_none_or_zero(value):
    """Return ``0`` when ``value`` is ``None``; otherwise return ``value`` as is.

    Used to turn an open slice bound (``None``) into a numeric zero.
    """
    if value is None:
        return 0
    return value
|
||||||
|
|
||||||
|
|
||||||
def process_slices(slices_tuple):
    """Fold a tuple of slices into one integer (used as a per-shard PRNG seed).

    The result is ``prod(starts) + prod(stops)``, where any bound that is
    ``None`` or ``0`` contributes the multiplicative identity ``1``.

    Previously such bounds contributed ``0``, which zeroed both products as
    soon as one axis was open (``slice(None)``) and made every shard produce
    the same value.

    Args:
        slices_tuple: iterable of ``slice`` objects, one per array axis.

    Returns:
        int: the combined value. It is not guaranteed to be unique per shard
        (different slice tuples can still collide), but it now varies with
        the slice bounds.
    """
    start_product = 1
    stop_product = 1

    for s in slices_tuple:
        # ``or 1`` maps both None and 0 to 1 so that an open or zero bound
        # leaves the running product untouched instead of collapsing it to 0.
        start_product *= s.start or 1
        stop_product *= s.stop or 1

    # Return the sum of the two products
    return int(start_product + stop_product)
|
||||||
|
|
||||||
|
|
||||||
def device_arange(pdims):
    """Build a sharded array with one locally generated value per shard.

    A device mesh of shape ``pdims`` is created, transposed, and wrapped in a
    ``Mesh`` with axes ``('z', 'y')``. The returned array has the shape of
    ``mesh.devices``. Each shard is a 1x1 array holding
    ``start0 + start1 * pdims[0]``, where ``start0``/``start1`` are the start
    bounds of the shard's index slices (``None`` counted as 0).

    Args:
        pdims: processor-grid shape passed to ``create_device_mesh``.

    Returns:
        The array produced by ``jax.make_array_from_callback``.
    """
    import jax
    from jax import numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    device_grid = mesh_utils.create_device_mesh(pdims)
    mesh = Mesh(device_grid.T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def generate_aranged(x):
        # ``x`` is the tuple of index slices handed to the callback.
        first_start = replace_none_or_zero(x[0].start)
        second_start = replace_none_or_zero(x[1].start)
        a = jnp.array([[first_start + second_start * pdims[0]]])
        print(f"index is {x} and value is {a}")
        return a

    return jax.make_array_from_callback(mesh.devices.shape,
                                        sharding,
                                        data_callback=generate_aranged)
|
||||||
|
|
||||||
|
|
||||||
def create_ones_spmd_array(global_shape, pdims):
    """Create a sharded 3D array filled with ones, plus the mesh it lives on.

    Args:
        global_shape: length-3 shape of the full array.
        pdims: length-2 processor grid; its product must equal
            ``jax.device_count()``.

    Returns:
        tuple: ``(global_array, mesh)``. The array is built with
        ``jax.make_array_from_callback`` over a ``('z', 'y')`` mesh, and every
        shard is a block of ones of the computed local shape.

    Raises:
        AssertionError: if the shape/grid lengths or the device count do not
            match the requirements above.
    """
    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    n_devices = jax.device_count()
    assert (len(global_shape) == 3)
    assert (len(pdims) == 2)
    assert (
        prod(pdims) == n_devices
    ), "The product of pdims must be equal to the number of MPI processes"

    # Axis 0 is divided by pdims[1] and axis 1 by pdims[0]; axis 2 is kept
    # whole.
    shard_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
                   global_shape[2])

    mesh = Mesh(mesh_utils.create_device_mesh(pdims).T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def _ones_shard(_):
        # The shard index is ignored: every shard has identical content.
        return jax.numpy.ones(shard_shape)

    global_array = jax.make_array_from_callback(global_shape,
                                                sharding,
                                                data_callback=_ones_shard)

    return global_array, mesh
|
||||||
|
|
||||||
|
|
||||||
# Helper function to create a 3D array and remap it to the global array
|
# Helper function to create a 3D array and remap it to the global array
|
||||||
def create_spmd_array(global_shape, pdims):
    """Create a sharded 3D array of normal random values, plus its mesh.

    Args:
        global_shape: length-3 shape of the full array.
        pdims: length-2 processor grid; its product must equal
            ``jax.device_count()``.

    Returns:
        tuple: ``(global_array, mesh)``. Each shard is drawn with
        ``jax.random.normal`` from a key seeded by ``process_slices`` applied
        to that shard's index slices.

    Raises:
        AssertionError: if the shape/grid lengths or the device count do not
            match the requirements above.
    """
    import jax
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding
    from jax.sharding import PartitionSpec as P

    n_devices = jax.device_count()
    assert (len(global_shape) == 3)
    assert (len(pdims) == 2)
    assert (
        prod(pdims) == n_devices
    ), "The product of pdims must be equal to the number of MPI processes"

    # Axis 0 is divided by pdims[1] and axis 1 by pdims[0]; axis 2 is kept
    # whole.
    shard_shape = (global_shape[0] // pdims[1], global_shape[1] // pdims[0],
                   global_shape[2])

    # Remap to the global array from the local slice
    mesh = Mesh(mesh_utils.create_device_mesh(pdims).T, axis_names=('z', 'y'))
    sharding = NamedSharding(mesh, P('z', 'y'))

    def _random_shard(index):
        # The seed is derived from the shard's index slices.
        key = jax.random.PRNGKey(process_slices(index))
        return jax.random.normal(key, shard_shape)

    global_array = jax.make_array_from_callback(global_shape,
                                                sharding,
                                                data_callback=_random_shard)

    return global_array, mesh
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue