From 49dd18a3f82a66b4fea334f251a492e46b8baac0 Mon Sep 17 00:00:00 2001
From: Wassim KABALAN <wastondev@gmail.com>
Date: Sat, 26 Oct 2024 22:47:26 +0200
Subject: [PATCH] add notebook examples

---
 notebooks/01-Introduction.ipynb | 320 ++++++++++++++++++++++++++++++++
 notebooks/02-MultiGPU_PM.ipynb  | 242 ++++++++++++++++++++++++
 notebooks/03-MultiHost_PM.ipynb | 157 ++++++++++++++++
 notebooks/03-MultiHost_PM.py    | 104 +++++++++++
 notebooks/Introduction.ipynb    | 192 -------------------
 notebooks/visualize.py          |  38 ++++
 6 files changed, 861 insertions(+), 192 deletions(-)
 create mode 100644 notebooks/01-Introduction.ipynb
 create mode 100644 notebooks/02-MultiGPU_PM.ipynb
 create mode 100644 notebooks/03-MultiHost_PM.ipynb
 create mode 100644 notebooks/03-MultiHost_PM.py
 delete mode 100644 notebooks/Introduction.ipynb
 create mode 100644 notebooks/visualize.py

diff --git a/notebooks/01-Introduction.ipynb b/notebooks/01-Introduction.ipynb
new file mode 100644
index 0000000..8b8dc8a
--- /dev/null
+++ b/notebooks/01-Introduction.ipynb
@@ -0,0 +1,320 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "view-in-github"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "id": "9Jy5BL1XiK1s",
+      "metadata": {
+        "id": "9Jy5BL1XiK1s"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "id": "c5f42bbe",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "c5f42bbe",
+        "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+            "Populating the interactive namespace from numpy and matplotlib\n"
+          ]
+        }
+      ],
+      "source": [
+        "import os\n",
+        "import jax\n",
+        "import jax.numpy as jnp\n",
+        "import jax_cosmo as jc\n",
+        "\n",
+        "from jax.experimental.ode import odeint\n",
+        "\n",
+        "from jaxpm.painting import cic_paint\n",
+        "from jaxpm.pm import linear_field, lpt, make_ode_fn"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "id": "281b4d3b",
+      "metadata": {
+        "id": "281b4d3b"
+      },
+      "outputs": [],
+      "source": [
+        "mesh_shape= [256, 256, 256]\n",
+        "box_size  = [256.,256.,256.]\n",
+        "snapshots = jnp.linspace(0.1,1.,3)\n",
+        "\n",
+        "@jax.jit\n",
+        "def run_simulation(omega_c, sigma8):\n",
+        "    # Create a small function to generate the matter power spectrum\n",
+        "    k = jnp.logspace(-4, 1, 128)\n",
+        "    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+        "    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+        "\n",
+        "    # Create initial conditions\n",
+        "    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+        "\n",
+        "    # Create particles\n",
+        "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1 , 3])\n",
+        "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+        "    \n",
+        "    # Initial displacement\n",
+        "    dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
+        "    \n",
+        "    # Evolve the simulation forward\n",
+        "    res = odeint(make_ode_fn(mesh_shape,particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
+        "    \n",
+        "    # Return the simulation volume at requested \n",
+        "\n",
+        "    return initial_conditions ,  particles + dx , res[0]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "826be667",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "826be667",
+        "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
+      },
+      "outputs": [],
+      "source": [
+        "initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)\n",
+        "%timeit initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "4e012ce8",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 323
+        },
+        "id": "4e012ce8",
+        "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "shape of grid_mesh: (256, 256, 256)\n"
+          ]
+        }
+      ],
+      "source": [
+        "from visualize import plot_fields\n",
+        "\n",
+        "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+        "for i , field in enumerate(ode_particles):\n",
+        "    fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+        "plot_fields(fields)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "b71824ed",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "mesh_shape= [256, 256, 256]\n",
+        "box_size  = [256.,256.,256.]\n",
+        "snapshots = jnp.linspace(0.1,1.,3)\n",
+        "\n",
+        "@jax.jit\n",
+        "def run_simulation(omega_c, sigma8):\n",
+        "    # Create a small function to generate the matter power spectrum\n",
+        "    k = jnp.logspace(-4, 1, 128)\n",
+        "    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+        "    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+        "\n",
+        "    # Create initial conditions\n",
+        "    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+        "\n",
+        "    # Create particles\n",
+        "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+        "    \n",
+        "    # Initial displacement\n",
+        "    dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
+        "    \n",
+        "    # Evolve the simulation forward\n",
+        "    res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
+        "    \n",
+        "    # Return the simulation volume at requested \n",
+        "\n",
+        "    return initial_conditions ,  dx , res[0]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "e9c9fd56",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)\n",
+        "%timeit initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "33b5e684",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from visualize import plot_fields\n",
+        "\n",
+        "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
+        "for i , field in enumerate(ode_displacements):\n",
+        "    fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+        "plot_fields(fields)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "4e050871",
+      "metadata": {
+        "id": "4e050871"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install diffrax\n",
+        "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "43504a1b",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "mesh_shape= [256, 256, 256]\n",
+        "box_size  = [256.,256.,256.]\n",
+        "snapshots = jnp.linspace(0.1,1.,3)\n",
+        "\n",
+        "@jax.jit\n",
+        "def run_simulation(omega_c, sigma8):\n",
+        "    # Create a small function to generate the matter power spectrum\n",
+        "    k = jnp.logspace(-4, 1, 128)\n",
+        "    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+        "    pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+        "\n",
+        "    # Create initial conditions\n",
+        "    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+        "\n",
+        "    # Create particles\n",
+        "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+        "    \n",
+        "    # Initial displacement\n",
+        "    dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
+        "    \n",
+        "    # Evolve the simulation forward\n",
+        "    ode_fn = make_ode_fn(mesh_shape)\n",
+        "    term = ODETerm(\n",
+        "        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
+        "    solver = LeapfrogMidpoint()\n",
+        "\n",
+        "    stepsize_controller = ConstantStepSize()\n",
+        "    res = diffeqsolve(term,\n",
+        "                      solver,\n",
+        "                      t0=0.1,\n",
+        "                      t1=1.,\n",
+        "                      dt0=0.01,\n",
+        "                      y0=jnp.stack([dx, p], axis=0),\n",
+        "                      args=cosmo,\n",
+        "                      saveat=SaveAt(ts=snapshots),\n",
+        "                      stepsize_controller=stepsize_controller)\n",
+        "\n",
+        "\n",
+        "    return initial_conditions ,  dx , res.ys , res.stats"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "19949ff1",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+        "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+        "print(f\"Solver Stats : {solver_stats}\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "76a26e98",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from visualize import plot_fields\n",
+        "\n",
+        "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
+        "for i , field in enumerate(ode_solutions):\n",
+        "    fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field[0])\n",
+        "plot_fields(fields)"
+      ]
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "include_colab_link": true,
+      "name": "Introduction.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.10.4"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+}
diff --git a/notebooks/02-MultiGPU_PM.ipynb b/notebooks/02-MultiGPU_PM.ipynb
new file mode 100644
index 0000000..bf64b39
--- /dev/null
+++ b/notebooks/02-MultiGPU_PM.ipynb
@@ -0,0 +1,242 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "colab_type": "text",
+        "id": "view-in-github"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "id": "9Jy5BL1XiK1s",
+      "metadata": {
+        "id": "9Jy5BL1XiK1s"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
+        "!pip install diffrax"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "id": "c5f42bbe",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "c5f42bbe",
+        "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+            "Populating the interactive namespace from numpy and matplotlib\n"
+          ]
+        }
+      ],
+      "source": [
+        "import os\n",
+        "import jax\n",
+        "import jax.numpy as jnp\n",
+        "import jax_cosmo as jc\n",
+        "\n",
+        "from jax.experimental.ode import odeint\n",
+        "from jaxpm.kernels import interpolate_power_spectrum\n",
+        "from jaxpm.painting import cic_paint\n",
+        "from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
+        "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "38df34e3",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "9edd2246",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from jax.experimental.mesh_utils import create_device_mesh\n",
+        "from jax.experimental.multihost_utils import process_allgather\n",
+        "from jax.sharding import Mesh, NamedSharding\n",
+        "from jax.sharding import PartitionSpec as P\n",
+        "from functools import partial\n",
+        "\n",
+        "all_gather = partial(process_allgather, tiled=True)\n",
+        "\n",
+        "pdims = (2, 4)\n",
+        "devices = create_device_mesh(pdims)\n",
+        "mesh = Mesh(devices, axis_names=('x', 'y'))\n",
+        "sharding = NamedSharding(mesh, P('x', 'y'))"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "id": "281b4d3b",
+      "metadata": {
+        "id": "281b4d3b"
+      },
+      "outputs": [],
+      "source": [
+        "mesh_shape = [1024, 1024, 1024]\n",
+        "box_size = [1024., 1024., 1024.]\n",
+        "halo_size = 128\n",
+        "snapshots = jnp.linspace(0.1, 1., 3)\n",
+        "\n",
+        "\n",
+        "@jax.jit\n",
+        "def run_simulation(omega_c, sigma8):\n",
+        "    # Create a small function to generate the matter power spectrum\n",
+        "    k = jnp.logspace(-4, 1, 128)\n",
+        "    pk = jc.power.linear_matter_power(\n",
+        "        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+        "    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
+        "\n",
+        "    # Create initial conditions\n",
+        "    initial_conditions = linear_field(mesh_shape,\n",
+        "                                      box_size,\n",
+        "                                      pk_fn,\n",
+        "                                      seed=jax.random.PRNGKey(0),\n",
+        "                                      sharding=sharding)\n",
+        "\n",
+        "    # Create particles\n",
+        "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
+        "                          axis=-1).reshape([-1, 3])\n",
+        "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+        "\n",
+        "    # Initial displacement\n",
+        "    dx, p, f = lpt(cosmo,\n",
+        "                   initial_conditions,\n",
+        "                   particles,\n",
+        "                   0.1,\n",
+        "                   halo_size=halo_size,\n",
+        "                   sharding=sharding)\n",
+        "\n",
+        "    # Evolve the simulation forward\n",
+        "    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
+        "    term = ODETerm(\n",
+        "        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
+        "    solver = LeapfrogMidpoint()\n",
+        "\n",
+        "    stepsize_controller = ConstantStepSize()\n",
+        "    res = diffeqsolve(term,\n",
+        "                      solver,\n",
+        "                      t0=0.1,\n",
+        "                      t1=1.,\n",
+        "                      dt0=0.01,\n",
+        "                      y0=jnp.stack([dx, p], axis=0),\n",
+        "                      args=cosmo,\n",
+        "                      saveat=SaveAt(ts=snapshots),\n",
+        "                      stepsize_controller=stepsize_controller)\n",
+        "\n",
+        "    return initial_conditions, dx, res.ys, res.stats"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "826be667",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "826be667",
+        "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
+      },
+      "outputs": [],
+      "source": [
+        "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+        "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+        "print(f\"Solver Stats : {solver_stats}\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "042cc55c",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "initial_conditions = all_gather(initial_conditions)\n",
+        "lpt_particles = all_gather(lpt_particles)\n",
+        "ode_particles = [all_gather(p) for p in ode_particles]"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "4e012ce8",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 323
+        },
+        "id": "4e012ce8",
+        "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "shape of grid_mesh: (256, 256, 256)\n"
+          ]
+        }
+      ],
+      "source": [
+        "from visualize import plot_fields\n",
+        "\n",
+        "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+        "for i , field in enumerate(ode_particles):\n",
+        "    fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+        "plot_fields(fields)"
+      ]
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "include_colab_link": true,
+      "name": "Introduction.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.10.4"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+}
diff --git a/notebooks/03-MultiHost_PM.ipynb b/notebooks/03-MultiHost_PM.ipynb
new file mode 100644
index 0000000..b7f1683
--- /dev/null
+++ b/notebooks/03-MultiHost_PM.ipynb
@@ -0,0 +1,157 @@
+{
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "id": "22803ddc",
+      "metadata": {
+        "colab_type": "text",
+        "id": "view-in-github"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 8,
+      "id": "9Jy5BL1XiK1s",
+      "metadata": {
+        "id": "9Jy5BL1XiK1s"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
+        "!pip install diffrax"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 1,
+      "id": "c5f42bbe",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "c5f42bbe",
+        "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+            "Populating the interactive namespace from numpy and matplotlib\n"
+          ]
+        }
+      ],
+      "source": [
+        "!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00  --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "7ebdfc00",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "!squeue -u $USER"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "c014316c",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "export JOB_ID=123456"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "b7eabac5",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "472dd4bf",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "data = np.load(\"multihost_pm.npz\")\n",
+        "initial_conditions = data['initial_conditions']\n",
+        "lpt_displacements = data['lpt_displacements']\n",
+        "ode_solutions = data['ode_solutions']\n",
+        "solver_stats = data['solver_stats']\n",
+        "print(f\"Solver stats: {solver_stats}\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "id": "4e012ce8",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 323
+        },
+        "id": "4e012ce8",
+        "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+      },
+      "outputs": [
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "shape of grid_mesh: (256, 256, 256)\n"
+          ]
+        }
+      ],
+      "source": [
+        "from visualize import plot_fields\n",
+        "\n",
+        "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+        "for i , field in enumerate(ode_particles):\n",
+        "    fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+        "plot_fields(fields)"
+      ]
+    }
+  ],
+  "metadata": {
+    "accelerator": "GPU",
+    "colab": {
+      "include_colab_link": true,
+      "name": "Introduction.ipynb",
+      "provenance": []
+    },
+    "kernelspec": {
+      "display_name": "Python 3",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.10.4"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+}
diff --git a/notebooks/03-MultiHost_PM.py b/notebooks/03-MultiHost_PM.py
new file mode 100644
index 0000000..06dccb7
--- /dev/null
+++ b/notebooks/03-MultiHost_PM.py
@@ -0,0 +1,104 @@
+import os
+
+os.environ["EQX_ON_ERROR"] = "nan"  # avoid an allgather caused by diffrax
+import jax
+
+jax.distributed.initialize()
+rank = jax.process_index()
+size = jax.process_count()
+
+import jax.numpy as jnp
+import jax_cosmo as jc
+
+from jaxpm.kernels import interpolate_power_spectrum
+from jaxpm.painting import cic_paint_dx
+from jaxpm.pm import linear_field, lpt, make_ode_fn
+
+from jax.experimental.mesh_utils import create_device_mesh
+from jax.experimental.multihost_utils import process_allgather
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from functools import partial
+import numpy as np
+
+from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
+
+all_gather = partial(process_allgather, tiled=True)
+
+pdims = (2, 4)
+devices = create_device_mesh(pdims)
+mesh = Mesh(devices, axis_names=('x', 'y'))
+sharding = NamedSharding(mesh, P('x', 'y'))
+
+mesh_shape = [2024, 1024, 1024]
+box_size = [1024., 1024., 1024.]
+halo_size = 512
+snapshots = jnp.linspace(0.1, 1., 2)
+
+
+@jax.jit
+def run_simulation(omega_c, sigma8):
+    # Create a small function to generate the matter power spectrum
+    k = jnp.logspace(-4, 1, 128)
+    pk = jc.power.linear_matter_power(
+        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
+    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
+
+    # Create initial conditions
+    initial_conditions = linear_field(mesh_shape,
+                                      box_size,
+                                      pk_fn,
+                                      seed=jax.random.PRNGKey(0),
+                                      sharding=sharding)
+
+    # Create particles
+    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
+                          axis=-1).reshape([-1, 3])
+    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
+
+    # Initial displacement
+    dx, p, _ = lpt(cosmo,
+                   initial_conditions,
+                   particles,
+                   0.1,
+                   halo_size=halo_size,
+                   sharding=sharding)
+
+    # Evolve the simulation forward
+    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
+    term = ODETerm(
+        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
+    solver = LeapfrogMidpoint()
+
+    stepsize_controller = ConstantStepSize()
+    res = diffeqsolve(term,
+                      solver,
+                      t0=0.1,
+                      t1=1.,
+                      dt0=0.01,
+                      y0=jnp.stack([dx, p], axis=0),
+                      args=cosmo,
+                      saveat=SaveAt(ts=snapshots),
+                      stepsize_controller=stepsize_controller)
+
+    return initial_conditions, dx, res.ys, res.stats
+
+
+initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
+print(f"[{rank}] Simulation completed")
+print(f"[{rank}] Solver stats: {solver_stats}")
+
+# Gather the results
+initial_conditions = all_gather(initial_conditions)
+lpt_displacements = all_gather(lpt_displacements)
+ode_solutions = [all_gather(sol) for sol in ode_solutions]
+
+if rank == 0:
+    np.savez("multihost_pm.npz",
+                initial_conditions=initial_conditions,
+                lpt_displacements=lpt_displacements,
+                ode_solutions=ode_solutions,
+                solver_stats=solver_stats)
+
+print(f"[{rank}] Simulation results saved")
+    
\ No newline at end of file
diff --git a/notebooks/Introduction.ipynb b/notebooks/Introduction.ipynb
deleted file mode 100644
index aa49d41..0000000
--- a/notebooks/Introduction.ipynb
+++ /dev/null
@@ -1,192 +0,0 @@
-{
-  "cells": [
-    {
-      "cell_type": "markdown",
-      "metadata": {
-        "colab_type": "text",
-        "id": "view-in-github"
-      },
-      "source": [
-        "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 8,
-      "id": "9Jy5BL1XiK1s",
-      "metadata": {
-        "id": "9Jy5BL1XiK1s"
-      },
-      "outputs": [],
-      "source": [
-        "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 1,
-      "id": "c5f42bbe",
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "c5f42bbe",
-        "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
-      },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
-            "Populating the interactive namespace from numpy and matplotlib\n"
-          ]
-        }
-      ],
-      "source": [
-        "%pylab inline\n",
-        "import os\n",
-        "import jax\n",
-        "import jax.numpy as jnp\n",
-        "import jax_cosmo as jc\n",
-        "\n",
-        "from jax.experimental.ode import odeint\n",
-        "\n",
-        "from jaxpm.painting import cic_paint\n",
-        "from jaxpm.pm import linear_field, lpt, make_ode_fn"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 2,
-      "id": "281b4d3b",
-      "metadata": {
-        "id": "281b4d3b"
-      },
-      "outputs": [],
-      "source": [
-        "mesh_shape= [256, 256, 256]\n",
-        "box_size  = [256.,256.,256.]\n",
-        "snapshots = jnp.linspace(0.1,1.,2)\n",
-        "\n",
-        "@jax.jit\n",
-        "def run_simulation(omega_c, sigma8):\n",
-        "    # Create a small function to generate the matter power spectrum\n",
-        "    k = jnp.logspace(-4, 1, 128)\n",
-        "    pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
-        "    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
-        "\n",
-        "    # Create initial conditions\n",
-        "    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
-        "\n",
-        "    # Create particles\n",
-        "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])\n",
-        "\n",
-        "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
-        "    \n",
-        "    # Initial displacement\n",
-        "    dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
-        "    \n",
-        "    # Evolve the simulation forward\n",
-        "    res = odeint(make_ode_fn(mesh_shape), [particles+dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
-        "    \n",
-        "    # Return the simulation volume at requested \n",
-        "    return res[0]"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": 3,
-      "id": "826be667",
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "826be667",
-        "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
-      },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
-            "pm_forces particles are Traced<ShapedArray(int32[16777216,3])>with<DynamicJaxprTrace(level=1/0)>\n",
-            "shape of displacement: (256, 256, 256)\n",
-            "pm_forces particles are Traced<ShapedArray(float32[16777216,3])>with<DynamicJaxprTrace(level=2/0)>\n"
-          ]
-        }
-      ],
-      "source": [
-        "res = run_simulation(0.25, 0.8)\n",
-        "#%timeit res = run_simulation(0.25, 0.8)"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "id": "4e012ce8",
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 323
-        },
-        "id": "4e012ce8",
-        "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
-      },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "shape of grid_mesh: (256, 256, 256)\n"
-          ]
-        }
-      ],
-      "source": [
-        "figure(figsize=[10,5])\n",
-        "subplot(121)\n",
-        "imshow(cic_paint(jnp.zeros(mesh_shape), res[0]).sum(axis=0),cmap='magma')\n",
-        "subplot(122)\n",
-        "imshow(cic_paint(jnp.zeros(mesh_shape), res[1]).sum(axis=0),cmap='magma')"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "id": "4e050871",
-      "metadata": {
-        "id": "4e050871"
-      },
-      "outputs": [],
-      "source": []
-    }
-  ],
-  "metadata": {
-    "accelerator": "GPU",
-    "colab": {
-      "include_colab_link": true,
-      "name": "Introduction.ipynb",
-      "provenance": []
-    },
-    "kernelspec": {
-      "display_name": "Python 3",
-      "language": "python",
-      "name": "python3"
-    },
-    "language_info": {
-      "codemirror_mode": {
-        "name": "ipython",
-        "version": 3
-      },
-      "file_extension": ".py",
-      "mimetype": "text/x-python",
-      "name": "python",
-      "nbconvert_exporter": "python",
-      "pygments_lexer": "ipython3",
-      "version": "3.10.4"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 5
-}
diff --git a/notebooks/visualize.py b/notebooks/visualize.py
new file mode 100644
index 0000000..0db2273
--- /dev/null
+++ b/notebooks/visualize.py
@@ -0,0 +1,38 @@
+import numpy as np
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+
+def plot_fields(fields_dict, sum_over=None):
+    """
+    Plots sum projections of 3D fields along different axes,
+    slicing only the first `sum_over` elements along each axis.
+
+    Args:
+    - fields: list of 3D arrays representing fields to plot
+    - names: list of names for each field, used in titles
+    - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
+    """
+    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
+    nb_rows = len(fields_dict)
+    nb_cols = 3
+    fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
+
+    def plot_subplots(proj_axis, field, row, title):
+        slicing = [slice(None)] * field.ndim
+        slicing[proj_axis] = slice(None, sum_over)
+        slicing = tuple(slicing)
+
+        # Sum projection over the specified axis and plot
+        axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1, 
+                                    cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
+        axes[row, proj_axis].set_xlabel('Mpc/h')
+        axes[row, proj_axis].set_ylabel('Mpc/h')
+        axes[row, proj_axis].set_title(title)
+
+    # Plot each field across the three axes
+    for i, (name, field) in enumerate(fields_dict.items()):
+        for proj_axis in range(3):
+            plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}")
+
+    plt.tight_layout()
+    plt.show()