diff --git a/jaxpm/plotting.py b/jaxpm/plotting.py new file mode 100644 index 0000000..5868a7b --- /dev/null +++ b/jaxpm/plotting.py @@ -0,0 +1,120 @@ +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + + +def plot_fields(fields_dict, sum_over=None): + """ + Plots sum projections of 3D fields along different axes, + slicing only the first `sum_over` elements along each axis. + + Args: + - fields: list of 3D arrays representing fields to plot + - names: list of names for each field, used in titles + - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8) + """ + sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8 + nb_rows = len(fields_dict) + nb_cols = 3 + fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows)) + + def plot_subplots(proj_axis, field, row, title): + slicing = [slice(None)] * field.ndim + slicing[proj_axis] = slice(None, sum_over) + slicing = tuple(slicing) + + # Sum projection over the specified axis and plot + axes[row, proj_axis].imshow( + field[slicing].sum(axis=proj_axis) + 1, + cmap='magma', + extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) + axes[row, proj_axis].set_xlabel('Mpc/h') + axes[row, proj_axis].set_ylabel('Mpc/h') + axes[row, proj_axis].set_title(title) + + # Plot each field across the three axes + for i, (name, field) in enumerate(fields_dict.items()): + for proj_axis in range(3): + plot_subplots(proj_axis, field, i, + f"{name} projection {proj_axis}") + + plt.tight_layout() + plt.show() + + +def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0): + """ + Plots a single projection (along axis 0) of 3D fields in a grid, + summing over the first `sum_over` elements along the 0-axis, with 4 images per row. + + Args: + - fields_dict: dictionary where keys are field names and values are 3D arrays + - sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8) + """ + sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8 + nb_fields = len(fields_dict) + nb_cols = 4 # Set number of images per row + nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows + + fig, axes = plt.subplots(nb_rows, + nb_cols, + figsize=(5 * nb_cols, 5 * nb_rows)) + axes = np.atleast_2d(axes) # Ensure axes is always a 2D array + + for i, (name, field) in enumerate(fields_dict.items()): + row, col = divmod(i, nb_cols) + + # Define the slice for the 0-axis projection + slicing = [slice(None)] * field.ndim + slicing[project_axis] = slice(None, sum_over) + slicing = tuple(slicing) + + # Sum projection over axis 0 and plot + axes[row, col].imshow(field[slicing].sum(axis=project_axis) + 1, + cmap='magma', + extent=[0, field.shape[1], 0, field.shape[2]]) + axes[row, col].set_xlabel('Mpc/h') + axes[row, col].set_ylabel('Mpc/h') + axes[row, col].set_title(f"{name} projection 0") + + # Remove any empty subplots + for j in range(i + 1, nb_rows * nb_cols): + fig.delaxes(axes.flatten()[j]) + + plt.tight_layout() + plt.show() + + +def stack_slices(array): + """ + Stacks 2D slices of an array into a single array based on provided partition dimensions. + + Args: + - array_slices: a 2D list of array slices (list of lists format) where + array_slices[i][j] is the slice located at row i, column j in the grid. + - pdims: a tuple representing the grid dimensions (rows, columns). + + Returns: + - A single array constructed by stacking the slices. + """ + # Initialize an empty list to store the vertically stacked rows + pdims = array.sharding.mesh.devices.shape + + field_slices = [] + + # Iterate over rows in pdims[0] + for i in range(pdims[0]): + row_slices = [] + + # Iterate over columns in pdims[1] + for j in range(pdims[1]): + slice_index = i * pdims[0] + j + row_slices.append(array.addressable_data(slice_index)) + # Stack the current row of slices vertically + stacked_row = np.hstack(row_slices) + field_slices.append(stacked_row) + + # Stack all rows horizontally to form the full array + full_array = np.vstack(field_slices) + + return full_array diff --git a/notebooks/01-Introduction.ipynb b/notebooks/01-Introduction.ipynb index d16d4ed..0cc9712 100644 --- a/notebooks/01-Introduction.ipynb +++ b/notebooks/01-Introduction.ipynb @@ -63,17 +63,18 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "281b4d3b", "metadata": { "id": "281b4d3b" }, "outputs": [], "source": [ - "mesh_shape = [256, 256, 256]\n", - "box_size = [256., 256., 256.]\n", + "mesh_shape = [128, 128, 128]\n", + "box_size = [128., 128., 128.]\n", "snapshots = jnp.array([0.1, 0.5, 1.0])\n", "\n", + "@jax.jit\n", "def run_simulation(omega_c, sigma8):\n", " # Create a small function to generate the matter power spectrum\n", " k = jnp.logspace(-4, 1, 128)\n", @@ -126,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "4e012ce8", "metadata": { "colab": { @@ -149,7 +150,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", "for i , field in enumerate(ode_particles[1:]):\n", @@ -169,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "b71824ed", "metadata": {}, "outputs": [ @@ -182,8 +183,8 @@ } ], "source": [ - "mesh_shape = [256, 256, 256]\n", - "box_size = [256., 256., 256.]\n", + "mesh_shape = [128, 128, 128]\n", + "box_size = [128., 128., 128.]\n", "snapshots = jnp.array([0.1, 0.5, 1.0])\n", "\n", "@jax.jit\n", @@ -216,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "33b5e684", "metadata": {}, "outputs": [ @@ -232,7 +233,8 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", + "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n", "for i , field in enumerate(ode_displacements[1:]):\n", " fields[f\"field_{i}\"] = cic_paint_dx(field)\n", diff --git a/notebooks/02-Advanced_usage.ipynb b/notebooks/02-Advanced_usage.ipynb index c2cc624..4a3e51e 100644 --- a/notebooks/02-Advanced_usage.ipynb +++ b/notebooks/02-Advanced_usage.ipynb @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -63,8 +63,8 @@ } ], "source": [ - "mesh_shape = [256, 256, 256]\n", - "box_size = [256., 256., 256.]\n", + "mesh_shape = [128, 128, 128]\n", + "box_size = [128., 128., 128.]\n", "snapshots = jnp.array([0.5, 1.0])\n", "\n", "@jax.jit\n", @@ -111,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -126,7 +126,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n", "for i , field in enumerate(ode_solutions):\n", @@ -145,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -160,8 +160,8 @@ "source": [ "from functools import partial \n", "\n", - "mesh_shape = [256, 256, 256]\n", - "box_size = [256., 256., 256.]\n", + "mesh_shape = [128, 128, 128]\n", + "box_size = [128., 128., 128.]\n", "snapshots = jnp.array([0.5,1.])\n", "\n", "@partial(jax.jit , static_argnums=(2,))\n", @@ -224,7 +224,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -236,8 +236,8 @@ } ], "source": [ - "mesh_shape = [256, 256, 256]\n", - "box_size = [256., 256., 256.]\n", + "mesh_shape = [128, 128, 128]\n", + "box_size = [128., 128., 128.]\n", "snapshots = jnp.array([0.1 ,0.5, 1.])\n", "\n", "@jax.jit\n", @@ -283,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -298,7 +298,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", "for i , field in enumerate(ode_particles[1:]):\n", @@ -321,7 +321,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -337,7 +337,7 @@ ], "source": [ "from math import prod\n", - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n", "center3d = (slice(None) , center,center) \n", @@ -364,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -380,7 +380,7 @@ ], "source": [ "from math import prod\n", - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n", "center3d = (slice(None) , center,center) \n", diff --git a/notebooks/03-MultiGPU_PM_Halo.ipynb b/notebooks/03-MultiGPU_PM_Halo.ipynb index 871aaff..0b47d7c 100644 --- a/notebooks/03-MultiGPU_PM_Halo.ipynb +++ b/notebooks/03-MultiGPU_PM_Halo.ipynb @@ -283,7 +283,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "4e012ce8", "metadata": { "colab": { @@ -306,7 +306,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n", "for i , field in enumerate(ode_solutions):\n", @@ -342,7 +342,7 @@ "lpt_displacements_g = all_gather(lpt_displacements)\n", "ode_solutions_g = [all_gather(p) for p in ode_solutions]\n", "\n", - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n", "for i , field in enumerate(ode_solutions):\n", @@ -447,7 +447,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": null, "id": "0acb5253", "metadata": {}, "outputs": [ @@ -463,7 +463,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions_g , \"LPT Field\" : lpt_field_g}\n", "for i , field in enumerate(ode_fields_g):\n", @@ -600,7 +600,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "59cfba84", "metadata": {}, "outputs": [ @@ -652,7 +652,7 @@ ], "source": [ "from math import prod\n", - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "field = ode_solutions[0]\n", "\n", diff --git a/notebooks/04-MultiGPU_PM_Solvers.ipynb b/notebooks/04-MultiGPU_PM_Solvers.ipynb index c85f351..6cd268d 100644 --- a/notebooks/04-MultiGPU_PM_Solvers.ipynb +++ b/notebooks/04-MultiGPU_PM_Solvers.ipynb @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -181,7 +181,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "\n", "mesh_shape = 1024\n", "box_size = 1000.\n", diff --git a/notebooks/05-MultiHost_PM.ipynb b/notebooks/05-MultiHost_PM.ipynb index 75363dc..a4830da 100644 --- a/notebooks/05-MultiHost_PM.ipynb +++ b/notebooks/05-MultiHost_PM.ipynb @@ -259,7 +259,7 @@ } ], "source": [ - "from visualize import plot_fields_single_projection\n", + "from jaxpm.plotting import plot_fields_single_projection\n", "fields = {\n", " \"Initial Conditions\": initial_conditions,\n", " \"LPT Field\": lpt_displacements,\n",