{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "# Distributed particle-mesh simulation with JaxPM\n",
    "\n",
    "This notebook shards a $1024^3$ particle-mesh simulation over a 2x4 mesh of accelerator devices. It draws linear initial conditions for a Planck15 cosmology with adjustable `Omega_c` and `sigma8`, applies an LPT displacement, integrates forward from 0.1 to 1 with a `diffrax` leapfrog-midpoint solver, and plots the painted fields at three snapshots.\n",
    "\n",
    "**It requires a TPU or GPU runtime with at least 8 devices.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9Jy5BL1XiK1s",
   "metadata": {
    "id": "9Jy5BL1XiK1s"
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git\n",
    "%pip install --quiet diffrax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f42bbe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c5f42bbe",
    "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import partial\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax_cosmo as jc\n",
    "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve\n",
    "from jax.experimental.mesh_utils import create_device_mesh\n",
    "from jax.experimental.multihost_utils import process_allgather\n",
    "from jax.experimental.ode import odeint\n",
    "from jax.sharding import Mesh, NamedSharding\n",
    "from jax.sharding import PartitionSpec as P\n",
    "\n",
    "from jaxpm.kernels import interpolate_power_spectrum\n",
    "from jaxpm.painting import cic_paint\n",
    "from jaxpm.pm import linear_field, lpt, make_ode_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "device-mesh-header",
   "metadata": {},
   "source": [
    "## Device mesh\n",
    "\n",
    "The simulation fields are partitioned over a 2x4 logical mesh of devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38df34e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9edd2246",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gathers a (possibly sharded) array into a single full array.\n",
    "all_gather = partial(process_allgather, tiled=True)\n",
    "\n",
    "# 2x4 logical device mesh; `sharding` partitions the first two array\n",
    "# dimensions over its 'x' and 'y' axes.\n",
    "pdims = (2, 4)\n",
    "n_mesh_devices = pdims[0] * pdims[1]\n",
    "# create_device_mesh needs exactly prod(pdims) devices, so take the first 8\n",
    "# instead of failing on runtimes with more devices than the check above requires.\n",
    "devices = create_device_mesh(pdims, devices=jax.devices()[:n_mesh_devices])\n",
    "mesh = Mesh(devices, axis_names=('x', 'y'))\n",
    "sharding = NamedSharding(mesh, P('x', 'y'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "simulation-header",
   "metadata": {},
   "source": [
    "## Simulation\n",
    "\n",
    "`run_simulation` is jitted, so all of its steps compile into one sharded program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281b4d3b",
   "metadata": {
    "id": "281b4d3b"
   },
   "outputs": [],
   "source": [
    "mesh_shape = [1024, 1024, 1024]\n",
    "box_size = [1024., 1024., 1024.]\n",
    "halo_size = 128\n",
    "snapshots = jnp.linspace(0.1, 1., 3)\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def run_simulation(omega_c, sigma8):\n",
    "    \"\"\"Run a sharded PM simulation for a Planck15 cosmology with the given Omega_c and sigma8.\n",
    "\n",
    "    Returns:\n",
    "        initial_conditions: the field produced by `linear_field` on `mesh_shape`.\n",
    "        dx: the first output of `lpt`, evaluated at 0.1.\n",
    "        ys: the ODE state saved at each time in `snapshots`; each state is the\n",
    "            pair [dx, p] stacked along axis 0.\n",
    "        stats: the diffrax solver statistics.\n",
    "    \"\"\"\n",
    "    # One cosmology object shared by the power spectrum, LPT and the ODE.\n",
    "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
    "\n",
    "    # Tabulate the linear matter power spectrum and interpolate it on the mesh.\n",
    "    k = jnp.logspace(-4, 1, 128)\n",
    "    pk = jc.power.linear_matter_power(cosmo, k)\n",
    "    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
    "\n",
    "    # Initial conditions; the seed is fixed, so runs differ only through the cosmology.\n",
    "    initial_conditions = linear_field(mesh_shape,\n",
    "                                      box_size,\n",
    "                                      pk_fn,\n",
    "                                      seed=jax.random.PRNGKey(0),\n",
    "                                      sharding=sharding)\n",
    "\n",
    "    # One particle per grid cell.\n",
    "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]),\n",
    "                          axis=-1).reshape([-1, 3])\n",
    "\n",
    "    # Initial displacement (the third output of lpt is not needed here).\n",
    "    dx, p, _ = lpt(cosmo,\n",
    "                   initial_conditions,\n",
    "                   particles,\n",
    "                   0.1,\n",
    "                   halo_size=halo_size,\n",
    "                   sharding=sharding)\n",
    "\n",
    "    # Evolve forward. The state is [dx, p] stacked along axis 0, so the pair\n",
    "    # returned by ode_fn is stacked the same way.\n",
    "    ode_fn = make_ode_fn(mesh_shape, halo_size=halo_size, sharding=sharding)\n",
    "    term = ODETerm(\n",
    "        lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
    "    solver = LeapfrogMidpoint()\n",
    "    stepsize_controller = ConstantStepSize()\n",
    "    res = diffeqsolve(term,\n",
    "                      solver,\n",
    "                      t0=0.1,\n",
    "                      t1=1.,\n",
    "                      dt0=0.01,\n",
    "                      y0=jnp.stack([dx, p], axis=0),\n",
    "                      args=cosmo,\n",
    "                      saveat=SaveAt(ts=snapshots),\n",
    "                      stepsize_controller=stepsize_controller)\n",
    "\n",
    "    return initial_conditions, dx, res.ys, res.stats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "run-header",
   "metadata": {},
   "source": [
    "## Run and time the simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826be667",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "826be667",
    "outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
   },
   "outputs": [],
"source": [ "initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", "%timeit initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8)\n", "print(f\"Solver Stats : {solver_stats}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "042cc55c", "metadata": {}, "outputs": [], "source": [ "initial_conditions = all_gather(initial_conditions)\n", "lpt_particles = all_gather(lpt_particles)\n", "ode_particles = [all_gather(p) for p in ode_particles]" ] }, { "cell_type": "code", "execution_count": null, "id": "4e012ce8", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 323 }, "id": "4e012ce8", "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "shape of grid_mesh: (256, 256, 256)\n" ] } ], "source": [ "from visualize import plot_fields\n", "\n", "fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n", "for i , field in enumerate(ode_particles):\n", " fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape) , field)\n", "plot_fields(fields)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "include_colab_link": true, "name": "Introduction.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" } }, "nbformat": 4, "nbformat_minor": 5 }