{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# **Animating Particle Mesh density fields**\n", "\n", "In this tutorial, we animate the density field of a particle mesh simulation using the `manim` library.\n", "\n", "The density fields are created exactly as in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb), using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To run a multi-host simulation, you first need to **allocate a job** with `salloc`, just as in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb). This command requests resources on an HPC cluster." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**A few hours later**\n", "\n", "Use `!squeue -u $USER -o \"%i %D %b\"` to **check the job ID** and verify your resource allocation.\n", "\n", "In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!squeue -u $USER -o \"%i %D %b\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# pop() with a default avoids a KeyError when a variable is not set\n", "for var in ('VSCODE_PROXY_URI', 'NO_PROXY', 'no_proxy'):\n", "    os.environ.pop(var, None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Checking Available Compute Resources\n", "\n", "Run the following command to initialize JAX distributed computing and display the devices available for this job:\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Multi-Host Simulation Script with Arguments (reminder)\n", "\n", "This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n", "\n", "- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n", "- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n", "- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n", "- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n", "- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n", "- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n", "\n", "### Running the Multi-Host Simulation Script\n", "\n", "To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n", "\n", "Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. 
This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n", "\n", "The command to run the multi-host simulation with these settings looks like this:\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import subprocess\n", "\n", "# Define parameters as variables\n", "jobid = \"467745\"\n", "num_processes = 32\n", "script_name = \"05-MultiHost_PM.py\"\n", "mesh_shape = (1024, 1024, 1024)\n", "box_size = (1000., 1000., 1000.)\n", "halo_size = 128\n", "solver = \"leapfrog\"\n", "pdims = (16, 2)\n", "snapshots = 8\n", "\n", "# Build the command as a list, incorporating variables\n", "command = [\n", "    \"srun\",\n", "    f\"--jobid={jobid}\",\n", "    \"-n\", str(num_processes),\n", "    \"python\", script_name,\n", "    \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n", "    \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n", "    \"--halo_size\", str(halo_size),\n", "    \"-s\", solver,\n", "    \"--pdims\", str(pdims[0]), str(pdims[1]),\n", "    \"--snapshots\", str(snapshots)\n", "]\n", "\n", "# Execute the command as a subprocess\n", "subprocess.run(command)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Projecting the 3D Density Fields to 2D\n", "\n", "To visualize the 3D density fields in 2D, we need to create a projection:\n", "\n", "- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n", "  - We sum the first one-eighth of the planes along the first axis to capture a slab of the density field.\n", "\n", "- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, and every `ode_solution_*` snapshot) to get 2D arrays that represent the density fields.\n", "\n", "### Applying the Magma Colormap\n", "\n", "To improve visualization, apply the \"magma\" colormap to each 2D 
projection:\n", "\n", "- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n", "  - First, normalize the array to the `[0, 1]` range.\n", "  - Then apply the colormap to create 8-bit RGB images, which will be used for the animation.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import colormaps\n", "\n", "# Define a function to project the 3D field to 2D\n", "def project_to_2d(field):\n", "    # Sum the first one-eighth of the planes along axis 0\n", "    sum_over = field.shape[0] // 8\n", "    slicing = [slice(None)] * field.ndim\n", "    slicing[0] = slice(None, sum_over)\n", "    slicing = tuple(slicing)\n", "\n", "    return field[slicing].sum(axis=0)\n", "\n", "\n", "def apply_colormap(array, cmap_name=\"magma\"):\n", "    cmap = colormaps[cmap_name]\n", "    # Rescale to [0, 1] before looking up colors\n", "    normalized_array = (array - array.min()) / (array.max() - array.min())\n", "    colored_image = cmap(normalized_array)[:, :, :3]  # Drop alpha channel for RGB\n", "    return (colored_image * 255).astype(np.uint8)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading and Visualizing Results\n", "\n", "After running the multi-host simulation, we load the saved results from disk:\n", "\n", "- **`initial_conditions.npy`**: Initial conditions for the simulation.\n", "- **`lpt_displacements.npy`**: Lagrangian Perturbation Theory (LPT) displacements.\n", "- **`ode_solution_*.npy`**: Solutions from the ODE solver at each snapshot.\n", "\n", "We now project the fields to 2D maps and apply the colormap.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n", "lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n", "ode_solutions = []\n", "# 8 matches the number of snapshots requested from the simulation\n", "for i in range(8):\n", "    ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "### Animating with Manim\n", "\n", "To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from manim import *\n", "config.media_width = \"100%\"\n", "config.verbosity = \"WARNING\"\n", "config.background_color = \"#00000000\" # Transparent background" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining the Animation in Manim\n", "\n", "This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n", "\n", "- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n", "- **Animation Sequence**:\n", " - The animation begins with a fade-in of the initial conditions.\n", " - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n", "\n", "To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define the animation in Manim\n", "class FieldTransition(Scene):\n", " def construct(self):\n", " init_conditions_img = ImageMobject(initial_conditions).scale(4)\n", " lpt_img = ImageMobject(lpt_displacements).scale(4)\n", " snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n", "\n", "\n", " # Place the images on top of each other initially\n", " lpt_img.move_to(init_conditions_img)\n", " for img in snapshots_imgs:\n", " img.move_to(init_conditions_img)\n", "\n", " # Show initial field and then transform between fields\n", " self.play(FadeIn(init_conditions_img))\n", " self.wait(0.2)\n", " self.play(Transform(init_conditions_img, lpt_img))\n", " 
self.wait(0.2)\n", " self.play(Transform(lpt_img, snapshots_imgs[0]))\n", " self.wait(0.2)\n", " for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n", " self.play(Transform(img1, img2))\n", " self.wait(0.2)\n", "\n", "%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 2 }