diff --git a/notebooks/01-Introduction.ipynb b/notebooks/01-Introduction.ipynb
new file mode 100644
index 0000000..8b8dc8a
--- /dev/null
+++ b/notebooks/01-Introduction.ipynb
@@ -0,0 +1,320 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "9Jy5BL1XiK1s",
+ "metadata": {
+ "id": "9Jy5BL1XiK1s"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c5f42bbe",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "c5f42bbe",
+ "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+ "Populating the interactive namespace from numpy and matplotlib\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import jax_cosmo as jc\n",
+ "\n",
+ "from jax.experimental.ode import odeint\n",
+ "\n",
+ "from jaxpm.painting import cic_paint\n",
+ "from jaxpm.pm import linear_field, lpt, make_ode_fn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "281b4d3b",
+ "metadata": {
+ "id": "281b4d3b"
+ },
+ "outputs": [],
+ "source": [
+ "mesh_shape= [256, 256, 256]\n",
+ "box_size = [256.,256.,256.]\n",
+ "snapshots = jnp.linspace(0.1,1.,3)\n",
+ "\n",
+ "@jax.jit\n",
+ "def run_simulation(omega_c, sigma8):\n",
+ " # Create a small function to generate the matter power spectrum\n",
+ " k = jnp.logspace(-4, 1, 128)\n",
+ " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+ " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+ "\n",
+ " # Create initial conditions\n",
+ " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+ "\n",
+ " # Create particles\n",
+ " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1 , 3])\n",
+ " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+ " \n",
+ " # Initial displacement\n",
+ " dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
+ " \n",
+ " # Evolve the simulation forward\n",
+ " res = odeint(make_ode_fn(mesh_shape,particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
+ " \n",
+ " # Return the simulation volume at requested \n",
+ "\n",
+ " return initial_conditions , particles + dx , res[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "826be667",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "826be667",
+ "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
+ },
+ "outputs": [],
+ "source": [
+ "initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)\n",
+ "%timeit initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e012ce8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 323
+ },
+ "id": "4e012ce8",
+ "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape of grid_mesh: (256, 256, 256)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from visualize import plot_fields\n",
+ "\n",
+ "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+ "for i , field in enumerate(ode_particles):\n",
+ " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+ "plot_fields(fields)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b71824ed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh_shape= [256, 256, 256]\n",
+ "box_size = [256.,256.,256.]\n",
+ "snapshots = jnp.linspace(0.1,1.,3)\n",
+ "\n",
+ "@jax.jit\n",
+ "def run_simulation(omega_c, sigma8):\n",
+ " # Create a small function to generate the matter power spectrum\n",
+ " k = jnp.logspace(-4, 1, 128)\n",
+ " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+ " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+ "\n",
+ " # Create initial conditions\n",
+ " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+ "\n",
+ " # Create particles\n",
+ " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+ " \n",
+ " # Initial displacement\n",
+ " dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
+ " \n",
+ " # Evolve the simulation forward\n",
+ " res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
+ " \n",
+ " # Return the simulation volume at requested \n",
+ "\n",
+ " return initial_conditions , dx , res[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e9c9fd56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)\n",
+ "%timeit initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33b5e684",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from visualize import plot_fields\n",
+ "\n",
+ "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
+ "for i , field in enumerate(ode_displacements):\n",
+ " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+ "plot_fields(fields)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e050871",
+ "metadata": {
+ "id": "4e050871"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install diffrax\n",
+ "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "43504a1b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mesh_shape= [256, 256, 256]\n",
+ "box_size = [256.,256.,256.]\n",
+ "snapshots = jnp.linspace(0.1,1.,3)\n",
+ "\n",
+ "@jax.jit\n",
+ "def run_simulation(omega_c, sigma8):\n",
+ " # Create a small function to generate the matter power spectrum\n",
+ " k = jnp.logspace(-4, 1, 128)\n",
+ " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+ " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
+ "\n",
+ " # Create initial conditions\n",
+ " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
+ "\n",
+ " # Create particles\n",
+ " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+ " \n",
+ " # Initial displacement\n",
+ " dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n",
+ " \n",
+ " # Evolve the simulation forward\n",
+ " ode_fn = make_ode_fn(mesh_shape)\n",
+ " term = ODETerm(\n",
+ " lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
+ " solver = LeapfrogMidpoint()\n",
+ "\n",
+ " stepsize_controller = ConstantStepSize()\n",
+ " res = diffeqsolve(term,\n",
+ " solver,\n",
+ " t0=0.1,\n",
+ " t1=1.,\n",
+ " dt0=0.01,\n",
+ " y0=jnp.stack([dx, p], axis=0),\n",
+ " args=cosmo,\n",
+ " saveat=SaveAt(ts=snapshots),\n",
+ " stepsize_controller=stepsize_controller)\n",
+ "\n",
+ "\n",
+ " return initial_conditions , dx , res.ys , res.stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "19949ff1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+ "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+ "print(f\"Solver Stats : {solver_stats}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "76a26e98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from visualize import plot_fields\n",
+ "\n",
+ "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n",
+ "for i , field in enumerate(ode_solutions):\n",
+ " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field[0])\n",
+ "plot_fields(fields)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "Introduction.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/02-MultiGPU_PM.ipynb b/notebooks/02-MultiGPU_PM.ipynb
new file mode 100644
index 0000000..bf64b39
--- /dev/null
+++ b/notebooks/02-MultiGPU_PM.ipynb
@@ -0,0 +1,242 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "9Jy5BL1XiK1s",
+ "metadata": {
+ "id": "9Jy5BL1XiK1s"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
+ "!pip install diffrax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c5f42bbe",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "c5f42bbe",
+ "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+ "Populating the interactive namespace from numpy and matplotlib\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import jax_cosmo as jc\n",
+ "\n",
+ "from jax.experimental.ode import odeint\n",
+ "from jaxpm.kernels import interpolate_power_spectrum\n",
+ "from jaxpm.painting import cic_paint\n",
+ "from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
+ "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "38df34e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9edd2246",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.experimental.mesh_utils import create_device_mesh\n",
+ "from jax.experimental.multihost_utils import process_allgather\n",
+ "from jax.sharding import Mesh, NamedSharding\n",
+ "from jax.sharding import PartitionSpec as P\n",
+ "from functools import partial\n",
+ "\n",
+ "all_gather = partial(process_allgather, tiled=True)\n",
+ "\n",
+ "pdims = (2, 4)\n",
+ "devices = create_device_mesh(pdims)\n",
+ "mesh = Mesh(devices, axis_names=('x', 'y'))\n",
+ "sharding = NamedSharding(mesh, P('x', 'y'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "281b4d3b",
+ "metadata": {
+ "id": "281b4d3b"
+ },
+ "outputs": [],
+ "source": [
+ "mesh_shape = [1024, 1024, 1024]\n",
+ "box_size = [1024., 1024., 1024.]\n",
+ "halo_size = 128\n",
+ "snapshots = jnp.linspace(0.1, 1., 3)\n",
+ "\n",
+ "\n",
+ "@jax.jit\n",
+ "def run_simulation(omega_c, sigma8):\n",
+ " # Create a small function to generate the matter power spectrum\n",
+ " k = jnp.logspace(-4, 1, 128)\n",
+ " pk = jc.power.linear_matter_power(\n",
+ " jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
+ " pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
+ "\n",
+ " # Create initial conditions\n",
+ " initial_conditions = linear_field(mesh_shape,\n",
+ " box_size,\n",
+ " pk_fn,\n",
+ " seed=jax.random.PRNGKey(0),\n",
+ " sharding=sharding)\n",
+ "\n",
+ " # Create particles\n",
+ " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
+ " axis=-1).reshape([-1, 3])\n",
+ " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
+ "\n",
+ " # Initial displacement\n",
+ " dx, p, f = lpt(cosmo,\n",
+ " initial_conditions,\n",
+ " particles,\n",
+ " 0.1,\n",
+ " halo_size=halo_size,\n",
+ " sharding=sharding)\n",
+ "\n",
+ " # Evolve the simulation forward\n",
+ " ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
+ " term = ODETerm(\n",
+ " lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
+ " solver = LeapfrogMidpoint()\n",
+ "\n",
+ " stepsize_controller = ConstantStepSize()\n",
+ " res = diffeqsolve(term,\n",
+ " solver,\n",
+ " t0=0.1,\n",
+ " t1=1.,\n",
+ " dt0=0.01,\n",
+ " y0=jnp.stack([dx, p], axis=0),\n",
+ " args=cosmo,\n",
+ " saveat=SaveAt(ts=snapshots),\n",
+ " stepsize_controller=stepsize_controller)\n",
+ "\n",
+ " return initial_conditions, dx, res.ys, res.stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "826be667",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "826be667",
+ "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
+ },
+ "outputs": [],
+ "source": [
+ "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+ "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n",
+ "print(f\"Solver Stats : {solver_stats}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "042cc55c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "initial_conditions = all_gather(initial_conditions)\n",
+ "lpt_particles = all_gather(lpt_particles)\n",
+ "ode_particles = [all_gather(p) for p in ode_particles]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e012ce8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 323
+ },
+ "id": "4e012ce8",
+ "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape of grid_mesh: (256, 256, 256)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from visualize import plot_fields\n",
+ "\n",
+ "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+ "for i , field in enumerate(ode_particles):\n",
+ " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+ "plot_fields(fields)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "Introduction.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/03-MultiHost_PM.ipynb b/notebooks/03-MultiHost_PM.ipynb
new file mode 100644
index 0000000..b7f1683
--- /dev/null
+++ b/notebooks/03-MultiHost_PM.ipynb
@@ -0,0 +1,157 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "22803ddc",
+ "metadata": {
+ "colab_type": "text",
+ "id": "view-in-github"
+ },
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "9Jy5BL1XiK1s",
+ "metadata": {
+ "id": "9Jy5BL1XiK1s"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
+ "!pip install diffrax"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c5f42bbe",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "c5f42bbe",
+ "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
+ "Populating the interactive namespace from numpy and matplotlib\n"
+ ]
+ }
+ ],
+ "source": [
+ "!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ebdfc00",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!squeue -u $USER"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c014316c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "export JOB_ID=123456"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7eabac5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "472dd4bf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "data = np.load(\"multihost_pm.npz\")\n",
+ "initial_conditions = data['initial_conditions']\n",
+ "lpt_displacements = data['lpt_displacements']\n",
+ "ode_solutions = data['ode_solutions']\n",
+ "solver_stats = data['solver_stats']\n",
+ "print(f\"Solver stats: {solver_stats}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e012ce8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 323
+ },
+ "id": "4e012ce8",
+ "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape of grid_mesh: (256, 256, 256)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from visualize import plot_fields\n",
+ "\n",
+ "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
+ "for i , field in enumerate(ode_particles):\n",
+ " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n",
+ "plot_fields(fields)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "include_colab_link": true,
+ "name": "Introduction.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/notebooks/03-MultiHost_PM.py b/notebooks/03-MultiHost_PM.py
new file mode 100644
index 0000000..06dccb7
--- /dev/null
+++ b/notebooks/03-MultiHost_PM.py
@@ -0,0 +1,104 @@
+import os
+
+os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax
+import jax
+
+jax.distributed.initialize()
+rank = jax.process_index()
+size = jax.process_count()
+
+import jax.numpy as jnp
+import jax_cosmo as jc
+
+from jaxpm.kernels import interpolate_power_spectrum
+from jaxpm.painting import cic_paint_dx
+from jaxpm.pm import linear_field, lpt, make_ode_fn
+
+from jax.experimental.mesh_utils import create_device_mesh
+from jax.experimental.multihost_utils import process_allgather
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from functools import partial
+import numpy as np
+
+from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
+
+all_gather = partial(process_allgather, tiled=True)
+
+pdims = (2, 4)
+devices = create_device_mesh(pdims)
+mesh = Mesh(devices, axis_names=('x', 'y'))
+sharding = NamedSharding(mesh, P('x', 'y'))
+
+mesh_shape = [2024, 1024, 1024]
+box_size = [1024., 1024., 1024.]
+halo_size = 512
+snapshots = jnp.linspace(0.1, 1., 2)
+
+
+@jax.jit
+def run_simulation(omega_c, sigma8):
+ # Create a small function to generate the matter power spectrum
+ k = jnp.logspace(-4, 1, 128)
+ pk = jc.power.linear_matter_power(
+ jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
+ pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
+
+ # Create initial conditions
+ initial_conditions = linear_field(mesh_shape,
+ box_size,
+ pk_fn,
+ seed=jax.random.PRNGKey(0),
+ sharding=sharding)
+
+ # Create particles
+ particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),
+ axis=-1).reshape([-1, 3])
+ cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)
+
+ # Initial displacement
+ dx, p, _ = lpt(cosmo,
+ initial_conditions,
+ particles,
+ 0.1,
+ halo_size=halo_size,
+ sharding=sharding)
+
+ # Evolve the simulation forward
+ ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)
+ term = ODETerm(
+ lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))
+ solver = LeapfrogMidpoint()
+
+ stepsize_controller = ConstantStepSize()
+ res = diffeqsolve(term,
+ solver,
+ t0=0.1,
+ t1=1.,
+ dt0=0.01,
+ y0=jnp.stack([dx, p], axis=0),
+ args=cosmo,
+ saveat=SaveAt(ts=snapshots),
+ stepsize_controller=stepsize_controller)
+
+ return initial_conditions, dx, res.ys, res.stats
+
+
+initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)
+print(f"[{rank}] Simulation completed")
+print(f"[{rank}] Solver stats: {solver_stats}")
+
+# Gather the results
+initial_conditions = all_gather(initial_conditions)
+lpt_displacements = all_gather(lpt_displacements)
+ode_solutions = [all_gather(sol) for sol in ode_solutions]
+
+if rank == 0:
+ np.savez("multihost_pm.npz",
+ initial_conditions=initial_conditions,
+ lpt_displacements=lpt_displacements,
+ ode_solutions=ode_solutions,
+ solver_stats=solver_stats)
+
+print(f"[{rank}] Simulation results saved")
+
\ No newline at end of file
diff --git a/notebooks/Introduction.ipynb b/notebooks/Introduction.ipynb
deleted file mode 100644
index aa49d41..0000000
--- a/notebooks/Introduction.ipynb
+++ /dev/null
@@ -1,192 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "view-in-github"
- },
- "source": [
- "
"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "9Jy5BL1XiK1s",
- "metadata": {
- "id": "9Jy5BL1XiK1s"
- },
- "outputs": [],
- "source": [
- "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "c5f42bbe",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "c5f42bbe",
- "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
- "Populating the interactive namespace from numpy and matplotlib\n"
- ]
- }
- ],
- "source": [
- "%pylab inline\n",
- "import os\n",
- "import jax\n",
- "import jax.numpy as jnp\n",
- "import jax_cosmo as jc\n",
- "\n",
- "from jax.experimental.ode import odeint\n",
- "\n",
- "from jaxpm.painting import cic_paint\n",
- "from jaxpm.pm import linear_field, lpt, make_ode_fn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "281b4d3b",
- "metadata": {
- "id": "281b4d3b"
- },
- "outputs": [],
- "source": [
- "mesh_shape= [256, 256, 256]\n",
- "box_size = [256.,256.,256.]\n",
- "snapshots = jnp.linspace(0.1,1.,2)\n",
- "\n",
- "@jax.jit\n",
- "def run_simulation(omega_c, sigma8):\n",
- " # Create a small function to generate the matter power spectrum\n",
- " k = jnp.logspace(-4, 1, 128)\n",
- " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
- " pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
- "\n",
- " # Create initial conditions\n",
- " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
- "\n",
- " # Create particles\n",
- " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])\n",
- "\n",
- " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
- " \n",
- " # Initial displacement\n",
- " dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n",
- " \n",
- " # Evolve the simulation forward\n",
- " res = odeint(make_ode_fn(mesh_shape), [particles+dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
- " \n",
- " # Return the simulation volume at requested \n",
- " return res[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "826be667",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "826be667",
- "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "particles are Tracedwith\n",
- "pm_forces particles are Tracedwith\n",
- "shape of displacement: (256, 256, 256)\n",
- "pm_forces particles are Tracedwith\n"
- ]
- }
- ],
- "source": [
- "res = run_simulation(0.25, 0.8)\n",
- "#%timeit res = run_simulation(0.25, 0.8)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4e012ce8",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 323
- },
- "id": "4e012ce8",
- "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "shape of grid_mesh: (256, 256, 256)\n"
- ]
- }
- ],
- "source": [
- "figure(figsize=[10,5])\n",
- "subplot(121)\n",
- "imshow(cic_paint(jnp.zeros(mesh_shape), res[0]).sum(axis=0),cmap='magma')\n",
- "subplot(122)\n",
- "imshow(cic_paint(jnp.zeros(mesh_shape), res[1]).sum(axis=0),cmap='magma')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4e050871",
- "metadata": {
- "id": "4e050871"
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "include_colab_link": true,
- "name": "Introduction.ipynb",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/notebooks/visualize.py b/notebooks/visualize.py
new file mode 100644
index 0000000..0db2273
--- /dev/null
+++ b/notebooks/visualize.py
@@ -0,0 +1,38 @@
+import numpy as np
+import jax.numpy as jnp
+import matplotlib.pyplot as plt
+
+def plot_fields(fields_dict, sum_over=None):
+ """
+ Plots sum projections of 3D fields along different axes,
+ slicing only the first `sum_over` elements along each axis.
+
+ Args:
+ - fields: list of 3D arrays representing fields to plot
+ - names: list of names for each field, used in titles
+ - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
+ """
+ sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
+ nb_rows = len(fields_dict)
+ nb_cols = 3
+ fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
+
+ def plot_subplots(proj_axis, field, row, title):
+ slicing = [slice(None)] * field.ndim
+ slicing[proj_axis] = slice(None, sum_over)
+ slicing = tuple(slicing)
+
+ # Sum projection over the specified axis and plot
+ axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1,
+ cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
+ axes[row, proj_axis].set_xlabel('Mpc/h')
+ axes[row, proj_axis].set_ylabel('Mpc/h')
+ axes[row, proj_axis].set_title(title)
+
+ # Plot each field across the three axes
+ for i, (name, field) in enumerate(fields_dict.items()):
+ for proj_axis in range(3):
+ plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}")
+
+ plt.tight_layout()
+ plt.show()