diff --git a/notebooks/01-Introduction.ipynb b/notebooks/01-Introduction.ipynb new file mode 100644 index 0000000..8b8dc8a --- /dev/null +++ b/notebooks/01-Introduction.ipynb @@ -0,0 +1,320 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9Jy5BL1XiK1s", + "metadata": { + "id": "9Jy5BL1XiK1s" + }, + "outputs": [], + "source": [ + "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c5f42bbe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c5f42bbe", + "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "import os\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax_cosmo as jc\n", + "\n", + "from jax.experimental.ode import odeint\n", + "\n", + "from jaxpm.painting import cic_paint\n", + "from jaxpm.pm import linear_field, lpt, make_ode_fn" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "281b4d3b", + "metadata": { + "id": "281b4d3b" + }, + "outputs": [], + "source": [ + "mesh_shape= [256, 256, 256]\n", + "box_size = [256.,256.,256.]\n", + "snapshots = jnp.linspace(0.1,1.,3)\n", + "\n", + "@jax.jit\n", + "def run_simulation(omega_c, sigma8):\n", + " # Create a small function to generate the matter power spectrum\n", + " k = jnp.logspace(-4, 1, 128)\n", + " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n", + " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n", + "\n", + " # Create initial conditions\n", + " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n", + "\n", + " # Create particles\n", + " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1 , 3])\n", + " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n", + " \n", + " # Initial displacement\n", + " dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n", + " \n", + " # Evolve the simulation forward\n", + " res = odeint(make_ode_fn(mesh_shape,particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n", + " \n", + " # Return the simulation volume at requested \n", + "\n", + " return initial_conditions , particles + dx , res[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "826be667", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "826be667", + "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1" + }, + "outputs": [], + "source": [ + "initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)\n", + "%timeit initial_conditions , lpt_particles , ode_particles = run_simulation(0.25, 0.8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e012ce8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 323 + }, + "id": "4e012ce8", + "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape of grid_mesh: (256, 256, 256)\n" + ] + } + ], + "source": [ + "from visualize import plot_fields\n", + "\n", + "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", + "for i , field in enumerate(ode_particles):\n", + " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n", + "plot_fields(fields)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b71824ed", + "metadata": {}, + "outputs": [], + "source": [ + "mesh_shape= [256, 256, 256]\n", + "box_size = [256.,256.,256.]\n", + "snapshots = jnp.linspace(0.1,1.,3)\n", + "\n", + "@jax.jit\n", + "def run_simulation(omega_c, sigma8):\n", + " # Create a small function to generate the matter power spectrum\n", + " k = jnp.logspace(-4, 1, 128)\n", + " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n", + " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n", + "\n", + " # Create initial conditions\n", + " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n", + "\n", + " # Create particles\n", + " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n", + " \n", + " # Initial displacement\n", + " dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n", + " \n", + " # Evolve the simulation forward\n", + " res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n", + " \n", + " # Return the simulation volume at requested \n", + "\n", + " return initial_conditions , dx , res[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9c9fd56", + "metadata": {}, + "outputs": [], + "source": [ + "initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)\n", + "%timeit initial_conditions , lpt_displacements , ode_displacements = run_simulation(0.25, 0.8)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33b5e684", + "metadata": {}, + "outputs": [], + "source": [ + "from visualize import plot_fields\n", + "\n", + "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n", + "for i , field in enumerate(ode_displacements):\n", + " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n", + "plot_fields(fields)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e050871", + "metadata": { + "id": "4e050871" + }, + "outputs": [], + "source": [ + "!pip install diffrax\n", + "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43504a1b", + "metadata": {}, + "outputs": [], + "source": [ + "mesh_shape= [256, 256, 256]\n", + "box_size = [256.,256.,256.]\n", + "snapshots = jnp.linspace(0.1,1.,3)\n", + "\n", + "@jax.jit\n", + "def run_simulation(omega_c, sigma8):\n", + " # Create a small function to generate the matter power spectrum\n", + " k = jnp.logspace(-4, 1, 128)\n", + " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n", + " pk_fn = lambda x: jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n", + "\n", + " # Create initial conditions\n", + " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n", + "\n", + " # Create particles\n", + " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n", + " \n", + " # Initial displacement\n", + " dx, p, f = lpt(cosmo, initial_conditions, 0.1)\n", + " \n", + " # Evolve the simulation forward\n", + " ode_fn = make_ode_fn(mesh_shape)\n", + " term = ODETerm(\n", + " lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n", + " solver = LeapfrogMidpoint()\n", + "\n", + " stepsize_controller = ConstantStepSize()\n", + " res = diffeqsolve(term,\n", + " solver,\n", + " t0=0.1,\n", + " t1=1.,\n", + " dt0=0.01,\n", + " y0=jnp.stack([dx, p], axis=0),\n", + " args=cosmo,\n", + " saveat=SaveAt(ts=snapshots),\n", + " stepsize_controller=stepsize_controller)\n", + "\n", + "\n", + " return initial_conditions , dx , res.ys , res.stats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19949ff1", + "metadata": {}, + "outputs": [], + "source": [ + "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", + "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", + "print(f\"Solver Stats : {solver_stats}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76a26e98", + "metadata": {}, + "outputs": [], + "source": [ + "from visualize import plot_fields\n", + "\n", + "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_displacements)}\n", + "for i , field in enumerate(ode_solutions):\n", + " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field[0])\n", + "plot_fields(fields)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "Introduction.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/02-MultiGPU_PM.ipynb b/notebooks/02-MultiGPU_PM.ipynb new file mode 100644 index 0000000..bf64b39 --- /dev/null +++ b/notebooks/02-MultiGPU_PM.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9Jy5BL1XiK1s", + "metadata": { + "id": "9Jy5BL1XiK1s" + }, + "outputs": [], + "source": [ + "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n", + "!pip install diffrax" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c5f42bbe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c5f42bbe", + "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "import os\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax_cosmo as jc\n", + "\n", + "from jax.experimental.ode import odeint\n", + "from jaxpm.kernels import interpolate_power_spectrum\n", + "from jaxpm.painting import cic_paint\n", + "from jaxpm.pm import linear_field, lpt, make_ode_fn\n", + "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38df34e3", + "metadata": {}, + "outputs": [], + "source": [ + "assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9edd2246", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental.mesh_utils import create_device_mesh\n", + "from jax.experimental.multihost_utils import process_allgather\n", + "from jax.sharding import Mesh, NamedSharding\n", + "from jax.sharding import PartitionSpec as P\n", + "from functools import partial\n", + "\n", + "all_gather = partial(process_allgather, tiled=True)\n", + "\n", + "pdims = (2, 4)\n", + "devices = create_device_mesh(pdims)\n", + "mesh = Mesh(devices, axis_names=('x', 'y'))\n", + "sharding = NamedSharding(mesh, P('x', 'y'))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "281b4d3b", + "metadata": { + "id": "281b4d3b" + }, + "outputs": [], + "source": [ + "mesh_shape = [1024, 1024, 1024]\n", + "box_size = [1024., 1024., 1024.]\n", + "halo_size = 128\n", + "snapshots = jnp.linspace(0.1, 1., 3)\n", + "\n", + "\n", + "@jax.jit\n", + "def run_simulation(omega_c, sigma8):\n", + " # Create a small function to generate the matter power spectrum\n", + " k = jnp.logspace(-4, 1, 128)\n", + " pk = jc.power.linear_matter_power(\n", + " jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n", + " pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n", + "\n", + " # Create initial conditions\n", + " initial_conditions = linear_field(mesh_shape,\n", + " box_size,\n", + " pk_fn,\n", + " seed=jax.random.PRNGKey(0),\n", + " sharding=sharding)\n", + "\n", + " # Create particles\n", + " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n", + " axis=-1).reshape([-1, 3])\n", + " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n", + "\n", + " # Initial displacement\n", + " dx, p, f = lpt(cosmo,\n", + " initial_conditions,\n", + " particles,\n", + " 0.1,\n", + " halo_size=halo_size,\n", + " sharding=sharding)\n", + "\n", + " # Evolve the simulation forward\n", + " ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n", + " term = ODETerm(\n", + " lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n", + " solver = LeapfrogMidpoint()\n", + "\n", + " stepsize_controller = ConstantStepSize()\n", + " res = diffeqsolve(term,\n", + " solver,\n", + " t0=0.1,\n", + " t1=1.,\n", + " dt0=0.01,\n", + " y0=jnp.stack([dx, p], axis=0),\n", + " args=cosmo,\n", + " saveat=SaveAt(ts=snapshots),\n", + " stepsize_controller=stepsize_controller)\n", + "\n", + " return initial_conditions, dx, res.ys, res.stats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "826be667", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "826be667", + "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1" + }, + "outputs": [], + "source": [ + "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", + "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", + "print(f\"Solver Stats : {solver_stats}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "042cc55c", + "metadata": {}, + "outputs": [], + "source": [ + "initial_conditions = all_gather(initial_conditions)\n", + "lpt_particles = all_gather(lpt_particles)\n", + "ode_particles = [all_gather(p) for p in ode_particles]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e012ce8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 323 + }, + "id": "4e012ce8", + "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape of grid_mesh: (256, 256, 256)\n" + ] + } + ], + "source": [ + "from visualize import plot_fields\n", + "\n", + "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", + "for i , field in enumerate(ode_particles):\n", + " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n", + "plot_fields(fields)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "Introduction.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/03-MultiHost_PM.ipynb b/notebooks/03-MultiHost_PM.ipynb new file mode 100644 index 0000000..b7f1683 --- /dev/null +++ b/notebooks/03-MultiHost_PM.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22803ddc", + "metadata": { + "colab_type": "text", + "id": "view-in-github" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9Jy5BL1XiK1s", + "metadata": { + "id": "9Jy5BL1XiK1s" + }, + "outputs": [], + "source": [ + "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n", + "!pip install diffrax" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c5f42bbe", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c5f42bbe", + "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", + "Populating the interactive namespace from numpy and matplotlib\n" + ] + } + ], + "source": [ + "!salloc --account=tkc@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:30:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ebdfc00", + "metadata": {}, + "outputs": [], + "source": [ + "!squeue -u $USER" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c014316c", + "metadata": {}, + "outputs": [], + "source": [ + "export JOB_ID=123456" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7eabac5", + "metadata": {}, + "outputs": [], + "source": [ + "!srun --jobid=$JOB_ID -n 16 python 03-MultiHost_PM.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "472dd4bf", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "data = np.load(\"multihost_pm.npz\")\n", + "initial_conditions = data['initial_conditions']\n", + "lpt_displacements = data['lpt_displacements']\n", + "ode_solutions = data['ode_solutions']\n", + "solver_stats = data['solver_stats']\n", + "print(f\"Solver stats: {solver_stats}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e012ce8", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 323 + }, + "id": "4e012ce8", + "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape of grid_mesh: (256, 256, 256)\n" + ] + } + ], + "source": [ + "from visualize import plot_fields\n", + "\n", + "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", + "for i , field in enumerate(ode_particles):\n", + " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n", + "plot_fields(fields)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "include_colab_link": true, + "name": "Introduction.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/03-MultiHost_PM.py b/notebooks/03-MultiHost_PM.py new file mode 100644 index 0000000..06dccb7 --- /dev/null +++ b/notebooks/03-MultiHost_PM.py @@ -0,0 +1,104 @@ +import os + +os.environ["EQX_ON_ERROR"] = "nan" # avoid an allgather caused by diffrax +import jax + +jax.distributed.initialize() +rank = jax.process_index() +size = jax.process_count() + +import jax.numpy as jnp +import jax_cosmo as jc + +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx +from jaxpm.pm import linear_field, lpt, make_ode_fn + +from jax.experimental.mesh_utils import create_device_mesh +from jax.experimental.multihost_utils import process_allgather +from jax.sharding import Mesh, NamedSharding +from jax.sharding import PartitionSpec as P +from functools import partial +import numpy as np + +from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve + +all_gather = partial(process_allgather, tiled=True) + +pdims = (2, 4) +devices = create_device_mesh(pdims) +mesh = Mesh(devices, axis_names=('x', 'y')) +sharding = NamedSharding(mesh, P('x', 'y')) + +mesh_shape = [2024, 1024, 1024] +box_size = [1024., 1024., 1024.] +halo_size = 512 +snapshots = jnp.linspace(0.1, 1., 2) + + +@jax.jit +def run_simulation(omega_c, sigma8): + # Create a small function to generate the matter power spectrum + k = jnp.logspace(-4, 1, 128) + pk = jc.power.linear_matter_power( + jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) + pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) + + # Create initial conditions + initial_conditions = linear_field(mesh_shape, + box_size, + pk_fn, + seed=jax.random.PRNGKey(0), + sharding=sharding) + + # Create particles + particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]), + axis=-1).reshape([-1, 3]) + cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8) + + # Initial displacement + dx, p, _ = lpt(cosmo, + initial_conditions, + particles, + 0.1, + halo_size=halo_size, + sharding=sharding) + + # Evolve the simulation forward + ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding) + term = ODETerm( + lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0)) + solver = LeapfrogMidpoint() + + stepsize_controller = ConstantStepSize() + res = diffeqsolve(term, + solver, + t0=0.1, + t1=1., + dt0=0.01, + y0=jnp.stack([dx, p], axis=0), + args=cosmo, + saveat=SaveAt(ts=snapshots), + stepsize_controller=stepsize_controller) + + return initial_conditions, dx, res.ys, res.stats + + +initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8) +print(f"[{rank}] Simulation completed") +print(f"[{rank}] Solver stats: {solver_stats}") + +# Gather the results +initial_conditions = all_gather(initial_conditions) +lpt_displacements = all_gather(lpt_displacements) +ode_solutions = [all_gather(sol) for sol in ode_solutions] + +if rank == 0: + np.savez("multihost_pm.npz", + initial_conditions=initial_conditions, + lpt_displacements=lpt_displacements, + ode_solutions=ode_solutions, + solver_stats=solver_stats) + +print(f"[{rank}] Simulation results saved") + \ No newline at end of file diff --git a/notebooks/Introduction.ipynb b/notebooks/Introduction.ipynb deleted file mode 100644 index aa49d41..0000000 --- a/notebooks/Introduction.ipynb +++ /dev/null @@ -1,192 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "view-in-github" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "9Jy5BL1XiK1s", - "metadata": { - "id": "9Jy5BL1XiK1s" - }, - "outputs": [], - "source": [ - "!pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c5f42bbe", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "c5f42bbe", - "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n", - "Populating the interactive namespace from numpy and matplotlib\n" - ] - } - ], - "source": [ - "%pylab inline\n", - "import os\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import jax_cosmo as jc\n", - "\n", - "from jax.experimental.ode import odeint\n", - "\n", - "from jaxpm.painting import cic_paint\n", - "from jaxpm.pm import linear_field, lpt, make_ode_fn" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "281b4d3b", - "metadata": { - "id": "281b4d3b" - }, - "outputs": [], - "source": [ - "mesh_shape= [256, 256, 256]\n", - "box_size = [256.,256.,256.]\n", - "snapshots = jnp.linspace(0.1,1.,2)\n", - "\n", - "@jax.jit\n", - "def run_simulation(omega_c, sigma8):\n", - " # Create a small function to generate the matter power spectrum\n", - " k = jnp.logspace(-4, 1, 128)\n", - " pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n", - " pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n", - "\n", - " # Create initial conditions\n", - " initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n", - "\n", - " # Create particles\n", - " particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),axis=-1).reshape([-1,3])\n", - "\n", - " cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n", - " \n", - " # Initial displacement\n", - " dx, p, f = lpt(cosmo, initial_conditions, particles, 0.1)\n", - " \n", - " # Evolve the simulation forward\n", - " res = odeint(make_ode_fn(mesh_shape), [particles+dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n", - " \n", - " # Return the simulation volume at requested \n", - " return res[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "826be667", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "826be667", - "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "particles are Tracedwith\n", - "pm_forces particles are Tracedwith\n", - "shape of displacement: (256, 256, 256)\n", - "pm_forces particles are Tracedwith\n" - ] - } - ], - "source": [ - "res = run_simulation(0.25, 0.8)\n", - "#%timeit res = run_simulation(0.25, 0.8)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e012ce8", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 323 - }, - "id": "4e012ce8", - "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "shape of grid_mesh: (256, 256, 256)\n" - ] - } - ], - "source": [ - "figure(figsize=[10,5])\n", - "subplot(121)\n", - "imshow(cic_paint(jnp.zeros(mesh_shape), res[0]).sum(axis=0),cmap='magma')\n", - "subplot(122)\n", - "imshow(cic_paint(jnp.zeros(mesh_shape), res[1]).sum(axis=0),cmap='magma')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4e050871", - "metadata": { - "id": "4e050871" - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "include_colab_link": true, - "name": "Introduction.ipynb", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/visualize.py b/notebooks/visualize.py new file mode 100644 index 0000000..0db2273 --- /dev/null +++ b/notebooks/visualize.py @@ -0,0 +1,38 @@ +import numpy as np +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def plot_fields(fields_dict, sum_over=None): + """ + Plots sum projections of 3D fields along different axes, + slicing only the first `sum_over` elements along each axis. + + Args: + - fields: list of 3D arrays representing fields to plot + - names: list of names for each field, used in titles + - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8) + """ + sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8 + nb_rows = len(fields_dict) + nb_cols = 3 + fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows)) + + def plot_subplots(proj_axis, field, row, title): + slicing = [slice(None)] * field.ndim + slicing[proj_axis] = slice(None, sum_over) + slicing = tuple(slicing) + + # Sum projection over the specified axis and plot + axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1, + cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) + axes[row, proj_axis].set_xlabel('Mpc/h') + axes[row, proj_axis].set_ylabel('Mpc/h') + axes[row, proj_axis].set_title(title) + + # Plot each field across the three axes + for i, (name, field) in enumerate(fields_dict.items()): + for proj_axis in range(3): + plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}") + + plt.tight_layout() + plt.show()