update example notebooks

This commit is contained in:
Wassim KABALAN 2024-10-30 01:58:37 +01:00
parent 4da4c66472
commit 72457d6c37
10 changed files with 2059 additions and 742 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,106 +0,0 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
all_gather = partial(process_allgather, tiled=True)
pdims = (2, 4)
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
mesh_shape = [512, 512, 512]
box_size = [500., 500., 1000.]
halo_size = 64
snapshots = jnp.linspace(0.1, 1., 2)
@jax.jit
def run_simulation(omega_c, sigma8):
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
# Create initial conditions
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
# Create particles
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
axis=-1).reshape([-1, 3])
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
halo_size=halo_size,
sharding=sharding)
# Evolve the simulation forward
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
solver = LeapfrogMidpoint()
stepsize_controller = ConstantStepSize()
res = diffeqsolve(term,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(ts=snapshots),
stepsize_controller=stepsize_controller)
return initial_conditions, dx, res.ys, res.stats
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
0.25, 0.8)
print(f"[{rank}] Simulation completed")
print(f"[{rank}] Solver stats: {solver_stats}")
# Gather the results
pm_dict = {
"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats
}
for i in range(len(ode_solutions)):
sol = ode_solutions[i]
pm_dict[f"ode_solution_{i}"] = all_gather(sol)
if rank == 0:
np.savez("multihost_pm.npz", **pm_dict)
print(f"[{rank}] Simulation results saved")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,171 @@
import os
os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
import jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
if rank == 0:
print(f"SIZE is {jax.device_count()}")
import argparse
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
PIDController, SaveAt, diffeqsolve)
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import (process_allgather,
sync_global_devices)
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
all_gather = partial(process_allgather, tiled=True)
def parse_arguments():
parser = argparse.ArgumentParser(
description="Run a cosmological simulation with JAX.")
parser.add_argument(
"-p",
"--pdims",
type=int,
nargs=2,
default=[1, jax.devices()],
help="Processor grid dimensions as two integers (e.g., 2 4).")
parser.add_argument(
"-m",
"--mesh_shape",
type=int,
nargs=3,
default=[512, 512, 512],
help="Shape of the simulation mesh as three values (e.g., 512 512 512)."
)
parser.add_argument(
"-b",
"--box_size",
type=float,
nargs=3,
default=[500.0, 500.0, 500.0],
help=
"Box size of the simulation as three values (e.g., 500.0 500.0 1000.0)."
)
parser.add_argument("-H",
"--halo_size",
type=int,
default=64,
help="Halo size for the simulation.")
parser.add_argument("-s",
"--solver",
type=str,
choices=['leapfrog', 'dopri8'],
default='leapfrog',
help="ODE solver choice: 'leapfrog' or 'dopri8'.")
return parser.parse_args()
def create_mesh_and_sharding(mesh_shape, pdims):
devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
return mesh, sharding
@partial(jax.jit, static_argnums=(2, 3, 4, 5, 6))
def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size,
solver_choice, sharding):
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape,
box_size,
pk_fn,
seed=jax.random.PRNGKey(0),
sharding=sharding)
particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
axis=-1).reshape([-1, 3])
cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
dx, p, _ = lpt(cosmo,
initial_conditions,
a=0.1,
halo_size=halo_size,
sharding=sharding)
ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
term = ODETerm(
lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
# Choose solver
solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()
stepsize_controller = ConstantStepSize(
) if solver_choice == "leapfrog" else PIDController(rtol=1e-5, atol=1e-5)
res = diffeqsolve(term,
solver,
t0=0.1,
t1=1.,
dt0=0.01,
y0=jnp.stack([dx, p], axis=0),
args=cosmo,
saveat=SaveAt(ts=jnp.array([0.5, 1.0])),
stepsize_controller=stepsize_controller)
ode_fields = [
cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding)
for sol in res.ys
]
lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
return initial_conditions, lpt_field, ode_fields, res.stats
def main():
args = parse_arguments()
mesh_shape = args.mesh_shape
box_size = args.box_size
halo_size = args.halo_size
solver_choice = args.solver
mesh, sharding = create_mesh_and_sharding(mesh_shape, args.pdims)
initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation(
0.25, 0.8, tuple(mesh_shape), tuple(box_size), halo_size,
solver_choice, sharding)
# Save initial conditions
initial_conditions_g = all_gather(initial_conditions)
if rank == 0:
print(f"[{rank}] Saving initial_conditions")
np.save("initial_conditions.npy", initial_conditions_g)
print(f"[{rank}] initial_conditions saved")
del initial_conditions_g, initial_conditions
# Save LPT displacements
lpt_displacements_g = all_gather(lpt_displacements)
if rank == 0:
print(f"[{rank}] Saving lpt_displacements")
np.save("lpt_displacements.npy", lpt_displacements_g)
print(f"[{rank}] lpt_displacements saved")
del lpt_displacements_g, lpt_displacements
# Save each ODE solution separately
for i, sol in enumerate(ode_solutions):
sol_g = all_gather(sol)
if rank == 0:
print(f"[{rank}] Saving ode_solution_{i}")
np.save(f"ode_solution_{i}.npy", sol_g)
print(f"[{rank}] ode_solution_{i} saved")
del sol_g
if __name__ == "__main__":
main()

View file

@ -42,32 +42,79 @@ def plot_fields(fields_dict, sum_over=None):
plt.show()
def plot_fields_single_projection(fields_dict, sum_over=None):
def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0):
"""
Plots a single projection (along axis 0) of 3D fields in one row,
summing over the first `sum_over` elements along the 0-axis.
Plots a single projection (along axis 0) of 3D fields in a grid,
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
Args:
- fields_dict: dictionary where keys are field names and values are 3D arrays
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_cols = len(fields_dict)
fig, axes = plt.subplots(1, nb_cols, figsize=(5 * nb_cols, 5))
nb_fields = len(fields_dict)
nb_cols = 4 # Set number of images per row
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
fig, axes = plt.subplots(nb_rows,
nb_cols,
figsize=(5 * nb_cols, 5 * nb_rows))
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
for i, (name, field) in enumerate(fields_dict.items()):
row, col = divmod(i, nb_cols)
# Define the slice for the 0-axis projection
slicing = [slice(None)] * field.ndim
slicing[0] = slice(None, sum_over)
slicing[project_axis] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
axes[i].imshow(field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]])
axes[i].set_xlabel('Mpc/h')
axes[i].set_ylabel('Mpc/h')
axes[i].set_title(f"{name} projection 0")
axes[row, col].imshow(field[slicing].sum(axis=project_axis) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]])
axes[row, col].set_xlabel('Mpc/h')
axes[row, col].set_ylabel('Mpc/h')
axes[row, col].set_title(f"{name} projection 0")
# Remove any empty subplots
for j in range(i + 1, nb_rows * nb_cols):
fig.delaxes(axes.flatten()[j])
plt.tight_layout()
plt.show()
def stack_slices(array):
"""
Stacks 2D slices of an array into a single array based on provided partition dimensions.
Args:
- array_slices: a 2D list of array slices (list of lists format) where
array_slices[i][j] is the slice located at row i, column j in the grid.
- pdims: a tuple representing the grid dimensions (rows, columns).
Returns:
- A single array constructed by stacking the slices.
"""
# Initialize an empty list to store the vertically stacked rows
pdims = array.sharding.mesh.devices.shape
field_slices = []
# Iterate over rows in pdims[0]
for i in range(pdims[0]):
row_slices = []
# Iterate over columns in pdims[1]
for j in range(pdims[1]):
slice_index = i * pdims[0] + j
row_slices.append(array.addressable_data(slice_index))
# Stack the current row of slices vertically
stacked_row = np.hstack(row_slices)
field_slices.append(stacked_row)
# Stack all rows horizontally to form the full array
full_array = np.vstack(field_slices)
return full_array