forked from guilhem_lavaux/JaxPM
remove deprecated stuff
Former-commit-id: da6fdea0556407c968dd31d44af832c997f2645c
This commit is contained in:
parent
bc6e57532d
commit
70ab9f1931
7 changed files with 2418 additions and 14 deletions
|
@ -1,14 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
#SBATCH -A m1727
|
|
||||||
#SBATCH -C gpu
|
|
||||||
#SBATCH -q debug
|
|
||||||
#SBATCH -t 0:05:00
|
|
||||||
#SBATCH -N 2
|
|
||||||
#SBATCH --ntasks-per-node=4
|
|
||||||
#SBATCH -c 32
|
|
||||||
#SBATCH --gpus-per-task=1
|
|
||||||
#SBATCH --gpu-bind=none
|
|
||||||
|
|
||||||
module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
|
|
||||||
export SLURM_CPU_BIND="cores"
|
|
||||||
srun python test_pfft.py
|
|
309
notebooks/01-Introduction.ipynb
Normal file
309
notebooks/01-Introduction.ipynb
Normal file
File diff suppressed because one or more lines are too long
413
notebooks/02-Advanced_usage.ipynb
Normal file
413
notebooks/02-Advanced_usage.ipynb
Normal file
File diff suppressed because one or more lines are too long
697
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
697
notebooks/03-MultiGPU_PM_Halo.ipynb
Normal file
File diff suppressed because one or more lines are too long
379
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
379
notebooks/04-MultiGPU_PM_Solvers.ipynb
Normal file
File diff suppressed because one or more lines are too long
300
notebooks/05-MultiHost_PM.ipynb
Normal file
300
notebooks/05-MultiHost_PM.ipynb
Normal file
File diff suppressed because one or more lines are too long
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
320
notebooks/06-Animating_PM_Fields.ipynb
Normal file
|
@ -0,0 +1,320 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# **Animating Particle Mesh density fields**\n",
|
||||||
|
"\n",
|
||||||
|
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
|
||||||
|
"\n",
|
||||||
|
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
|
||||||
|
"\n",
|
||||||
|
"just like in notebook [**05-MultiHost_PM.ipynb**]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**A few hours later**\n",
|
||||||
|
"\n",
|
||||||
|
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
|
||||||
|
"\n",
|
||||||
|
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!squeue -u $USER -o \"%i %D %b\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"del os.environ['VSCODE_PROXY_URI']\n",
|
||||||
|
"del os.environ['NO_PROXY']\n",
|
||||||
|
"del os.environ['no_proxy']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Checking Available Compute Resources\n",
|
||||||
|
"\n",
|
||||||
|
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Multi-Host Simulation Script with Arguments (reminder)\n",
|
||||||
|
"\n",
|
||||||
|
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
|
||||||
|
"\n",
|
||||||
|
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
|
||||||
|
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
|
||||||
|
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
||||||
|
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
||||||
|
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
||||||
|
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
||||||
|
"\n",
|
||||||
|
"### Running the Multi-Host Simulation Script\n",
|
||||||
|
"\n",
|
||||||
|
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
|
||||||
|
"\n",
|
||||||
|
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
|
||||||
|
"\n",
|
||||||
|
"The command to run the multi-host simulation with these settings will look something like this:\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import subprocess\n",
|
||||||
|
"\n",
|
||||||
|
"# Define parameters as variables\n",
|
||||||
|
"jobid = \"467745\"\n",
|
||||||
|
"num_processes = 32\n",
|
||||||
|
"script_name = \"05-MultiHost_PM.py\"\n",
|
||||||
|
"mesh_shape = (1024, 1024, 1024)\n",
|
||||||
|
"box_size = (1000., 1000., 1000.)\n",
|
||||||
|
"halo_size = 128\n",
|
||||||
|
"solver = \"leapfrog\"\n",
|
||||||
|
"pdims = (16, 2)\n",
|
||||||
|
"snapshots = 8\n",
|
||||||
|
"\n",
|
||||||
|
"# Build the command as a list, incorporating variables\n",
|
||||||
|
"command = [\n",
|
||||||
|
" \"srun\",\n",
|
||||||
|
" f\"--jobid={jobid}\",\n",
|
||||||
|
" \"-n\", str(num_processes),\n",
|
||||||
|
" \"python\", script_name,\n",
|
||||||
|
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
||||||
|
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
||||||
|
" \"--halo_size\", str(halo_size),\n",
|
||||||
|
" \"-s\", solver,\n",
|
||||||
|
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
||||||
|
" \"--snapshots\", str(snapshots)\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"# Execute the command as a subprocess\n",
|
||||||
|
"subprocess.run(command)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Projecting the 3D Density Fields to 2D\n",
|
||||||
|
"\n",
|
||||||
|
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
|
||||||
|
"\n",
|
||||||
|
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
|
||||||
|
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
|
||||||
|
"\n",
|
||||||
|
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
|
||||||
|
"\n",
|
||||||
|
"### Applying the Magma Colormap\n",
|
||||||
|
"\n",
|
||||||
|
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
|
||||||
|
"\n",
|
||||||
|
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
|
||||||
|
" - First, normalize the array to the `[0, 1]` range.\n",
|
||||||
|
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from matplotlib import colormaps\n",
|
||||||
|
"\n",
|
||||||
|
"# Define a function to project the 3D field to 2D\n",
|
||||||
|
"def project_to_2d(field):\n",
|
||||||
|
" sum_over = field.shape[0] // 8\n",
|
||||||
|
" slicing = [slice(None)] * field.ndim\n",
|
||||||
|
" slicing[0] = slice(None, sum_over)\n",
|
||||||
|
" slicing = tuple(slicing)\n",
|
||||||
|
"\n",
|
||||||
|
" return field[slicing].sum(axis=0)\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def apply_colormap(array, cmap_name=\"magma\"):\n",
|
||||||
|
" cmap = colormaps[cmap_name]\n",
|
||||||
|
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
|
||||||
|
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
|
||||||
|
" return (colored_image * 255).astype(np.uint8)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Loading and Visualizing Results\n",
|
||||||
|
"\n",
|
||||||
|
"After running the multi-host simulation, we load the saved results from disk:\n",
|
||||||
|
"\n",
|
||||||
|
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
|
||||||
|
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
|
||||||
|
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
|
||||||
|
"\n",
|
||||||
|
"We will now project the fields to 2D maps and apply the color map\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
|
||||||
|
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
|
||||||
|
"ode_solutions = []\n",
|
||||||
|
"for i in range(8):\n",
|
||||||
|
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Animating with Manim\n",
|
||||||
|
"\n",
|
||||||
|
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from manim import *\n",
|
||||||
|
"config.media_width = \"100%\"\n",
|
||||||
|
"config.verbosity = \"WARNING\"\n",
|
||||||
|
"config.background_color = \"#00000000\" # Transparent background"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Defining the Animation in Manim\n",
|
||||||
|
"\n",
|
||||||
|
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
|
||||||
|
"\n",
|
||||||
|
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
|
||||||
|
"- **Animation Sequence**:\n",
|
||||||
|
" - The animation begins with a fade-in of the initial conditions.\n",
|
||||||
|
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
|
||||||
|
"\n",
|
||||||
|
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Define the animation in Manim\n",
|
||||||
|
"class FieldTransition(Scene):\n",
|
||||||
|
" def construct(self):\n",
|
||||||
|
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
|
||||||
|
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
|
||||||
|
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
" # Place the images on top of each other initially\n",
|
||||||
|
" lpt_img.move_to(init_conditions_img)\n",
|
||||||
|
" for img in snapshots_imgs:\n",
|
||||||
|
" img.move_to(init_conditions_img)\n",
|
||||||
|
"\n",
|
||||||
|
" # Show initial field and then transform between fields\n",
|
||||||
|
" self.play(FadeIn(init_conditions_img))\n",
|
||||||
|
" self.wait(0.2)\n",
|
||||||
|
" self.play(Transform(init_conditions_img, lpt_img))\n",
|
||||||
|
" self.wait(0.2)\n",
|
||||||
|
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
|
||||||
|
" self.wait(0.2)\n",
|
||||||
|
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
|
||||||
|
" self.play(Transform(img1, img2))\n",
|
||||||
|
" self.wait(0.2)\n",
|
||||||
|
"\n",
|
||||||
|
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.10.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue