mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
update example notebooks
This commit is contained in:
parent
4da4c66472
commit
72457d6c37
10 changed files with 2059 additions and 742 deletions
File diff suppressed because one or more lines are too long
418
notebooks/02-Advanced_usage.ipynb
Normal file
418
notebooks/02-Advanced_usage.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
693
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
693
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,106 +0,0 @@
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
|
||||||
import jax
|
|
||||||
|
|
||||||
jax.distributed.initialize()
|
|
||||||
rank = jax.process_index()
|
|
||||||
size = jax.process_count()
|
|
||||||
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import jax_cosmo as jc
|
|
||||||
import numpy as np
|
|
||||||
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
|
|
||||||
diffeqsolve)
|
|
||||||
from jax.experimental.mesh_utils import create_device_mesh
|
|
||||||
from jax.experimental.multihost_utils import process_allgather
|
|
||||||
from jax.sharding import Mesh, NamedSharding
|
|
||||||
from jax.sharding import PartitionSpec as P
|
|
||||||
|
|
||||||
from jaxpm.kernels import interpolate_power_spectrum
|
|
||||||
from jaxpm.painting import cic_paint_dx
|
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
|
||||||
|
|
||||||
all_gather = partial(process_allgather, tiled=True)
|
|
||||||
|
|
||||||
pdims = (2, 4)
|
|
||||||
devices = create_device_mesh(pdims)
|
|
||||||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
|
||||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
|
||||||
|
|
||||||
mesh_shape = [512, 512, 512]
|
|
||||||
box_size = [500., 500., 1000.]
|
|
||||||
halo_size = 64
|
|
||||||
snapshots = jnp.linspace(0.1, 1., 2)
|
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def run_simulation(omega_c, sigma8):
|
|
||||||
# Create a small function to generate the matter power spectrum
|
|
||||||
k = jnp.logspace(-4, 1, 128)
|
|
||||||
pk = jc.power.linear_matter_power(
|
|
||||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
|
||||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
|
||||||
|
|
||||||
# Create initial conditions
|
|
||||||
initial_conditions = linear_field(mesh_shape,
|
|
||||||
box_size,
|
|
||||||
pk_fn,
|
|
||||||
seed=jax.random.PRNGKey(0),
|
|
||||||
sharding=sharding)
|
|
||||||
|
|
||||||
# Create particles
|
|
||||||
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
|
|
||||||
axis=-1).reshape([-1, 3])
|
|
||||||
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
|
||||||
|
|
||||||
# Initial displacement
|
|
||||||
dx, p, _ = lpt(cosmo,
|
|
||||||
initial_conditions,
|
|
||||||
a=0.1,
|
|
||||||
halo_size=halo_size,
|
|
||||||
sharding=sharding)
|
|
||||||
|
|
||||||
# Evolve the simulation forward
|
|
||||||
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
|
||||||
term = ODETerm(
|
|
||||||
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
|
||||||
solver = LeapfrogMidpoint()
|
|
||||||
|
|
||||||
stepsize_controller = ConstantStepSize()
|
|
||||||
res = diffeqsolve(term,
|
|
||||||
solver,
|
|
||||||
t0=0.1,
|
|
||||||
t1=1.,
|
|
||||||
dt0=0.01,
|
|
||||||
y0=jnp.stack([dx, p], axis=0),
|
|
||||||
args=cosmo,
|
|
||||||
saveat=SaveAt(ts=snapshots),
|
|
||||||
stepsize_controller=stepsize_controller)
|
|
||||||
|
|
||||||
return initial_conditions, dx, res.ys, res.stats
|
|
||||||
|
|
||||||
|
|
||||||
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
|
||||||
0.25, 0.8)
|
|
||||||
print(f"[{rank}] Simulation completed")
|
|
||||||
print(f"[{rank}] Solver stats: {solver_stats}")
|
|
||||||
|
|
||||||
# Gather the results
|
|
||||||
|
|
||||||
pm_dict = {
|
|
||||||
"initial_conditions": all_gather(initial_conditions),
|
|
||||||
"lpt_displacements": all_gather(lpt_displacements),
|
|
||||||
"solver_stats": solver_stats
|
|
||||||
}
|
|
||||||
|
|
||||||
for i in range(len(ode_solutions)):
|
|
||||||
sol = ode_solutions[i]
|
|
||||||
pm_dict[f"ode_solution_{i}"] = all_gather(sol)
|
|
||||||
|
|
||||||
if rank == 0:
|
|
||||||
np.savez("multihost_pm.npz", **pm_dict)
|
|
||||||
|
|
||||||
print(f"[{rank}] Simulation results saved")
|
|
383
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
383
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
File diff suppressed because one or more lines are too long
287
notebooks/05-MultiHost_PM.ipynb
Normal file
287
notebooks/05-MultiHost_PM.ipynb
Normal file
File diff suppressed because one or more lines are too long
171
notebooks/05-MultiHost_PM.py
Normal file
171
notebooks/05-MultiHost_PM.py
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
|
||||||
|
import jax
|
||||||
|
|
||||||
|
jax.distributed.initialize()
|
||||||
|
rank = jax.process_index()
|
||||||
|
size = jax.process_count()
|
||||||
|
if rank == 0:
|
||||||
|
print(f"SIZE is {jax.device_count()}")
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_cosmo as jc
|
||||||
|
import numpy as np
|
||||||
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
|
PIDController, SaveAt, diffeqsolve)
|
||||||
|
from jax.experimental.mesh_utils import create_device_mesh
|
||||||
|
from jax.experimental.multihost_utils import (process_allgather,
|
||||||
|
sync_global_devices)
|
||||||
|
from jax.sharding import Mesh, NamedSharding
|
||||||
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
from jaxpm.kernels import interpolate_power_spectrum
|
||||||
|
from jaxpm.painting import cic_paint_dx
|
||||||
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
|
|
||||||
|
all_gather = partial(process_allgather, tiled=True)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_arguments():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run a cosmological simulation with JAX.")
|
||||||
|
parser.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--pdims",
|
||||||
|
type=int,
|
||||||
|
nargs=2,
|
||||||
|
default=[1, jax.devices()],
|
||||||
|
help="Processor grid dimensions as two integers (e.g., 2 4).")
|
||||||
|
parser.add_argument(
|
||||||
|
"-m",
|
||||||
|
"--mesh_shape",
|
||||||
|
type=int,
|
||||||
|
nargs=3,
|
||||||
|
default=[512, 512, 512],
|
||||||
|
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-b",
|
||||||
|
"--box_size",
|
||||||
|
type=float,
|
||||||
|
nargs=3,
|
||||||
|
default=[500.0, 500.0, 500.0],
|
||||||
|
help=
|
||||||
|
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
|
||||||
|
)
|
||||||
|
parser.add_argument("-H",
|
||||||
|
"--halo_size",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="Halo size for the simulation.")
|
||||||
|
parser.add_argument("-s",
|
||||||
|
"--solver",
|
||||||
|
type=str,
|
||||||
|
choices=['leapfrog', 'dopri8'],
|
||||||
|
default='leapfrog',
|
||||||
|
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def create_mesh_and_sharding(mesh_shape, pdims):
|
||||||
|
devices = create_device_mesh(pdims)
|
||||||
|
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||||
|
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||||
|
return mesh, sharding
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
|
||||||
|
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
|
||||||
|
solver_choice, sharding):
|
||||||
|
k = jnp.logspace(-4, 1, 128)
|
||||||
|
pk = jc.power.linear_matter_power(
|
||||||
|
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||||
|
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||||
|
|
||||||
|
initial_conditions = linear_field(mesh_shape,
|
||||||
|
box_size,
|
||||||
|
pk_fn,
|
||||||
|
seed=jax.random.PRNGKey(0),
|
||||||
|
sharding=sharding)
|
||||||
|
|
||||||
|
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
|
||||||
|
axis=-1).reshape([-1, 3])
|
||||||
|
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
|
||||||
|
|
||||||
|
dx, p, _ = lpt(cosmo,
|
||||||
|
initial_conditions,
|
||||||
|
a=0.1,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding)
|
||||||
|
|
||||||
|
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
|
||||||
|
term = ODETerm(
|
||||||
|
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
|
||||||
|
|
||||||
|
# Choose solver
|
||||||
|
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
|
||||||
|
stepsize_controller = ConstantStepSize(
|
||||||
|
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
|
||||||
|
res = diffeqsolve(term,
|
||||||
|
solver,
|
||||||
|
t0=0.1,
|
||||||
|
t1=1.,
|
||||||
|
dt0=0.01,
|
||||||
|
y0=jnp.stack([dx, p], axis=0),
|
||||||
|
args=cosmo,
|
||||||
|
saveat=SaveAt(ts=jnp.array([0.5, 1.0])),
|
||||||
|
stepsize_controller=stepsize_controller)
|
||||||
|
|
||||||
|
ode_fields = [
|
||||||
|
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
|
||||||
|
for sol in res.ys
|
||||||
|
]
|
||||||
|
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
|
||||||
|
return initial_conditions, lpt_field, ode_fields, res.stats
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_arguments()
|
||||||
|
mesh_shape = args.mesh_shape
|
||||||
|
box_size = args.box_size
|
||||||
|
halo_size = args.halo_size
|
||||||
|
solver_choice = args.solver
|
||||||
|
|
||||||
|
mesh, sharding = create_mesh_and_sharding(mesh_shape, args.pdims)
|
||||||
|
|
||||||
|
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
|
||||||
|
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
|
||||||
|
solver_choice, sharding)
|
||||||
|
|
||||||
|
# Save initial conditions
|
||||||
|
initial_conditions_g = all_gather(initial_conditions)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving initial_conditions")
|
||||||
|
np.save("initial_conditions.npy", initial_conditions_g)
|
||||||
|
print(f"[{rank}] initial_conditions saved")
|
||||||
|
del initial_conditions_g, initial_conditions
|
||||||
|
|
||||||
|
# Save LPT displacements
|
||||||
|
lpt_displacements_g = all_gather(lpt_displacements)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving lpt_displacements")
|
||||||
|
np.save("lpt_displacements.npy", lpt_displacements_g)
|
||||||
|
print(f"[{rank}] lpt_displacements saved")
|
||||||
|
del lpt_displacements_g, lpt_displacements
|
||||||
|
|
||||||
|
# Save each ODE solution separately
|
||||||
|
for i, sol in enumerate(ode_solutions):
|
||||||
|
sol_g = all_gather(sol)
|
||||||
|
if rank == 0:
|
||||||
|
print(f"[{rank}] Saving ode_solution_{i}")
|
||||||
|
np.save(f"ode_solution_{i}.npy", sol_g)
|
||||||
|
print(f"[{rank}] ode_solution_{i} saved")
|
||||||
|
del sol_g
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -42,32 +42,79 @@ def plot_fields(fields_dict, sum_over=None):
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def plot_fields_single_projection(fields_dict, sum_over=None):
|
def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0):
|
||||||
"""
|
"""
|
||||||
Plots a single projection (along axis 0) of 3D fields in one row,
|
Plots a single projection (along axis 0) of 3D fields in a grid,
|
||||||
summing over the first `sum_over` elements along the 0-axis.
|
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
||||||
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
||||||
"""
|
"""
|
||||||
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||||
nb_cols = len(fields_dict)
|
nb_fields = len(fields_dict)
|
||||||
fig, axes = plt.subplots(1, nb_cols, figsize=(5 * nb_cols, 5))
|
nb_cols = 4 # Set number of images per row
|
||||||
|
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(nb_rows,
|
||||||
|
nb_cols,
|
||||||
|
figsize=(5 * nb_cols, 5 * nb_rows))
|
||||||
|
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
|
||||||
|
|
||||||
for i, (name, field) in enumerate(fields_dict.items()):
|
for i, (name, field) in enumerate(fields_dict.items()):
|
||||||
|
row, col = divmod(i, nb_cols)
|
||||||
|
|
||||||
# Define the slice for the 0-axis projection
|
# Define the slice for the 0-axis projection
|
||||||
slicing = [slice(None)] * field.ndim
|
slicing = [slice(None)] * field.ndim
|
||||||
slicing[0] = slice(None, sum_over)
|
slicing[project_axis] = slice(None, sum_over)
|
||||||
slicing = tuple(slicing)
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
# Sum projection over axis 0 and plot
|
# Sum projection over axis 0 and plot
|
||||||
axes[i].imshow(field[slicing].sum(axis=0) + 1,
|
axes[row, col].imshow(field[slicing].sum(axis=project_axis) + 1,
|
||||||
cmap='magma',
|
cmap='magma',
|
||||||
extent=[0, field.shape[1], 0, field.shape[2]])
|
extent=[0, field.shape[1], 0, field.shape[2]])
|
||||||
axes[i].set_xlabel('Mpc/h')
|
axes[row, col].set_xlabel('Mpc/h')
|
||||||
axes[i].set_ylabel('Mpc/h')
|
axes[row, col].set_ylabel('Mpc/h')
|
||||||
axes[i].set_title(f"{name} projection 0")
|
axes[row, col].set_title(f"{name} projection 0")
|
||||||
|
|
||||||
|
# Remove any empty subplots
|
||||||
|
for j in range(i + 1, nb_rows * nb_cols):
|
||||||
|
fig.delaxes(axes.flatten()[j])
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def stack_slices(array):
|
||||||
|
"""
|
||||||
|
Stacks 2D slices of an array into a single array based on provided partition dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- array_slices: a 2D list of array slices (list of lists format) where
|
||||||
|
array_slices[i][j] is the slice located at row i, column j in the grid.
|
||||||
|
- pdims: a tuple representing the grid dimensions (rows, columns).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A single array constructed by stacking the slices.
|
||||||
|
"""
|
||||||
|
# Initialize an empty list to store the vertically stacked rows
|
||||||
|
pdims = array.sharding.mesh.devices.shape
|
||||||
|
|
||||||
|
field_slices = []
|
||||||
|
|
||||||
|
# Iterate over rows in pdims[0]
|
||||||
|
for i in range(pdims[0]):
|
||||||
|
row_slices = []
|
||||||
|
|
||||||
|
# Iterate over columns in pdims[1]
|
||||||
|
for j in range(pdims[1]):
|
||||||
|
slice_index = i * pdims[0] + j
|
||||||
|
row_slices.append(array.addressable_data(slice_index))
|
||||||
|
# Stack the current row of slices vertically
|
||||||
|
stacked_row = np.hstack(row_slices)
|
||||||
|
field_slices.append(stacked_row)
|
||||||
|
|
||||||
|
# Stack all rows horizontally to form the full array
|
||||||
|
full_array = np.vstack(field_slices)
|
||||||
|
|
||||||
|
return full_array
|
||||||
|
|
Loading…
Add table
Reference in a new issue