mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 17:47:11 +00:00
301 lines
1.9 MiB
Text
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0877d04e",
   "metadata": {},
   "source": [
    "# **Multi-Host Particle Mesh Simulation**\n",
    "\n",
    "In this notebook, we extend our Particle Mesh simulation across **multiple nodes**, enabling simulations at scales not achievable on a single machine. By leveraging distributed GPUs across hosts, we handle larger mesh shapes and box sizes efficiently.\n",
    "\n",
    "> **Note**: Since there’s no direct way to run a multi-host notebook, I’ll guide you step by step through allocating an interactive job and launching the simulation script from it.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
    "\n",
    "> **Note**: You can alternatively use `sbatch` with a SLURM script to submit the job. The exact `salloc` parameters may vary depending on your specific HPC cluster configuration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f42bbe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c5f42bbe",
    "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
   },
   "outputs": [],
   "source": [
    "!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac6585f3",
   "metadata": {},
   "source": [
    "> **Note**: These `salloc` parameters are configured for the **Jean Zay** supercomputer in France. Adaptations might be necessary if using a different HPC cluster.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74928ff7",
   "metadata": {},
   "source": [
    "**A few hours later**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c52f89cc",
   "metadata": {},
   "source": [
    "Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation (`%i` is the job ID, `%D` the number of nodes, and `%b` the generic resources requested per node).\n",
    "\n",
    "In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "7ebdfc00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "JOBID NODES TRES_PER_NODE\n",
      "467745 4 gres/gpu:8\n"
     ]
    }
   ],
   "source": [
    "!squeue -u $USER -o \"%i %D %b\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88bd7ef8",
   "metadata": {},
   "source": [
    "Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66dbe8f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# pop() with a default is a no-op when the variable is not set, whereas del raises KeyError\n",
    "for var in ('VSCODE_PROXY_URI', 'NO_PROXY', 'no_proxy'):\n",
    "    os.environ.pop(var, None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36479cc7",
   "metadata": {},
   "source": [
    "### Checking Available Compute Resources\n",
    "\n",
    "Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c78b8a4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3), CudaDevice(id=4), CudaDevice(id=5), CudaDevice(id=6), CudaDevice(id=7), CudaDevice(id=8), CudaDevice(id=9), CudaDevice(id=10), CudaDevice(id=11), CudaDevice(id=12), CudaDevice(id=13), CudaDevice(id=14), CudaDevice(id=15), CudaDevice(id=16), CudaDevice(id=17), CudaDevice(id=18), CudaDevice(id=19), CudaDevice(id=20), CudaDevice(id=21), CudaDevice(id=22), CudaDevice(id=23), CudaDevice(id=24), CudaDevice(id=25), CudaDevice(id=26), CudaDevice(id=27), CudaDevice(id=28), CudaDevice(id=29), CudaDevice(id=30), CudaDevice(id=31)]\n"
     ]
    }
   ],
   "source": [
    "!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b24e2e65",
   "metadata": {},
   "source": [
    "### Running the Multi-Host Simulation Script\n",
    "\n",
    "Run the simulation script across 32 processes:\n",
    "\n",
    "```bash\n",
    "srun --jobid=467745 -n 32 python 05-MultiHost_PM.py --mesh_shape 1024 1024 1024 --box_size 1000. 1000. 1000. --halo_size 128 -s leapfrog --pdims 16 2\n",
    "```\n",
    "The script, located in the same directory as this notebook, is named [**05-MultiHost_PM.py**](05-MultiHost_PM.py).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df27af30",
   "metadata": {},
   "source": [
    "### Multi-Host Simulation Script with Arguments\n",
    "\n",
    "This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
    "\n",
    "- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, e.g., `16 2` for a 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
    "- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
    "- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
    "- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
    "- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
    "- **`--snapshots`** (`-st`): Number of snapshots to save (warning: this increases memory usage).\n",
    "\n",
    "The script also saves results across nodes.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d13679",
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "\n",
    "# Define parameters as variables\n",
    "jobid = \"467745\"\n",
    "num_processes = 32\n",
    "script_name = \"05-MultiHost_PM.py\"\n",
    "mesh_shape = (1024, 1024, 1024)\n",
    "box_size = (1000., 1000., 1000.)\n",
    "halo_size = 128\n",
    "solver = \"leapfrog\"\n",
    "pdims = (16, 2)\n",
    "snapshots = 2\n",
    "\n",
    "# Build the command as a list, incorporating variables\n",
    "command = [\n",
    "    \"srun\",\n",
    "    f\"--jobid={jobid}\",\n",
    "    \"-n\", str(num_processes),\n",
    "    \"python\", script_name,\n",
    "    \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
    "    \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
    "    \"--halo_size\", str(halo_size),\n",
    "    \"-s\", solver,\n",
    "    \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
    "    \"--snapshots\", str(snapshots)\n",
    "]\n",
    "\n",
    "# Execute the command as a subprocess\n",
    "# check=True raises CalledProcessError if srun exits with a non-zero status\n",
    "subprocess.run(command, check=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45333bf0",
   "metadata": {},
   "source": [
    "### Loading and Visualizing Results\n",
    "\n",
    "After running the multi-host simulation, we load the saved results from the `fields/` directory:\n",
    "\n",
    "- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
    "- **`lpt_displacements.npy`**: Displacements computed with Lagrangian Perturbation Theory (LPT).\n",
    "- **`ode_solution_0.npy`** and **`ode_solution_1.npy`**: Solutions from the ODE solver at each snapshot.\n",
    "\n",
    "We then use `plot_fields_single_projection` to visualize these fields and observe the results across multiple snapshots.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "472dd4bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "initial_conditions = np.load('fields/initial_conditions.npy')\n",
    "lpt_displacements = np.load('fields/lpt_displacements.npy')\n",
    "ode_solution_0 = np.load('fields/ode_solution_0.npy')\n",
    "ode_solution_1 = np.load('fields/ode_solution_1.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4e012ce8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "id": "4e012ce8",
    "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
   },
   "outputs": [
    {
     "data": {
      "image/png": "(base64-encoded PNG omitted)",
      "text/plain": [
       "<Figure size 2000x500 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from jaxpm.plotting import plot_fields_single_projection\n",
    "fields = {\n",
    "    \"Initial Conditions\": initial_conditions,\n",
    "    \"LPT Field\": lpt_displacements,\n",
    "    \"ODE Solution 0\": ode_solution_0,\n",
    "    \"ODE Solution 1\": ode_solution_1\n",
    "}\n",
    "plot_fields_single_projection(fields)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "name": "Introduction.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}