{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "view-in-github",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "intro",
   "metadata": {},
   "source": [
    "# Introduction to JaxPM\n",
    "\n",
    "This notebook runs the same particle-mesh (PM) N-body simulation three ways and plots the resulting fields:\n",
    "\n",
    "1. `jax.experimental.ode.odeint` evolving absolute particle positions.\n",
    "2. `odeint` evolving displacements, without an explicit particle grid.\n",
    "3. `diffrax` (`LeapfrogMidpoint`, constant step size) evolving displacements.\n",
    "\n",
    "All three use the same cosmology, the same linear initial field (fixed seed) and LPT at $a = 0.1$. Each saves its state at the scale factors in `snapshots`.\n",
    "\n",
    "**Requirements:** the mesh is $256^3$, so a GPU runtime is recommended. A `visualize` module providing `plot_fields` must be importable; it is not installed by the cell below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "setup-md",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9Jy5BL1XiK1s",
   "metadata": {
    "id": "9Jy5BL1XiK1s"
   },
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the environment of the running kernel.\n",
    "%pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git diffrax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f42bbe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c5f42bbe"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax_cosmo as jc\n",
    "from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve\n",
    "from jax.experimental.ode import odeint\n",
    "\n",
    "from jaxpm.painting import cic_paint\n",
    "from jaxpm.pm import linear_field, lpt, make_ode_fn\n",
    "\n",
    "# `visualize` is not installed above; presumably a local module shipped next to this notebook.\n",
    "from visualize import plot_fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mesh resolution (cells per side) and box size, shared by all three runs.\n",
    "mesh_shape = [256, 256, 256]\n",
    "box_size = [256., 256., 256.]\n",
    "\n",
    "# Scale factor of the LPT initial state. The ODE starts from that state, so the\n",
    "# first saved snapshot must be this same scale factor.\n",
    "A_INIT = 0.1\n",
    "snapshots = jnp.linspace(A_INIT, 1., 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "helpers",
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_initial_conditions(cosmo):\n",
    "    \"\"\"Draw the linear density field for `cosmo` on `mesh_shape` (seed fixed to 0).\n",
    "\n",
    "    The linear matter power spectrum is tabulated at 128 log-spaced k values in\n",
    "    [1e-4, 10] and linearly interpolated by `pk_fn`, which `linear_field` samples.\n",
    "    \"\"\"\n",
    "    k = jnp.logspace(-4, 1, 128)\n",
    "    pk = jc.power.linear_matter_power(cosmo, k)\n",
    "\n",
    "    def pk_fn(x):\n",
    "        return jnp.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
    "\n",
    "    return linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
    "\n",
    "\n",
    "def paint_snapshots(initial_conditions, lpt_state, snapshot_states):\n",
    "    \"\"\"Collect the fields passed to `plot_fields`.\n",
    "\n",
    "    Returns a dict containing:\n",
    "    - the initial conditions;\n",
    "    - the LPT state, painted onto an empty mesh with `cic_paint`;\n",
    "    - one painted field per entry of `snapshot_states` (keys `field_0`, `field_1`, ...).\n",
    "    \"\"\"\n",
    "    fields = {\"Initial Conditions\": initial_conditions,\n",
    "              \"LPT Field\": cic_paint(jnp.zeros(mesh_shape), lpt_state)}\n",
    "    for i, state in enumerate(snapshot_states):\n",
    "        fields[f\"field_{i}\"] = cic_paint(jnp.zeros(mesh_shape), state)\n",
    "    return fields"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec1-md",
   "metadata": {},
   "source": [
    "## 1. `odeint` on particle positions\n",
    "\n",
    "Particles start on the regular mesh grid. The ODE state is `[grid + LPT displacement, p]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281b4d3b",
   "metadata": {
    "id": "281b4d3b"
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def run_simulation_positions(omega_c, sigma8):\n",
    "    \"\"\"PM simulation that evolves absolute particle positions with `odeint`.\n",
    "\n",
    "    Returns the linear initial field, the particle positions after LPT at A_INIT,\n",
    "    and the particle positions at each scale factor in `snapshots`.\n",
    "    \"\"\"\n",
    "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
    "    initial_conditions = linear_initial_conditions(cosmo)\n",
    "\n",
    "    # One particle per mesh cell, flattened to a (n_particles, 3) array of grid coordinates.\n",
    "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]), axis=-1).reshape([-1, 3])\n",
    "\n",
    "    # LPT initial state; `a` is passed by keyword so it cannot bind to another parameter.\n",
    "    dx, p, _ = lpt(cosmo, initial_conditions, particles, a=A_INIT)\n",
    "\n",
    "    # odeint returns the same structure as the initial state [positions, p],\n",
    "    # with a leading snapshot axis on each entry.\n",
    "    res = odeint(make_ode_fn(mesh_shape, particles), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
    "\n",
    "    return initial_conditions, particles + dx, res[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826be667",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "826be667"
   },
   "outputs": [],
   "source": [
    "# The first call also compiles the function; the timing below reuses the compiled version.\n",
    "initial_conditions, lpt_particles, ode_particles = run_simulation_positions(0.25, 0.8)\n",
    "# JAX dispatches asynchronously: block on the outputs so %timeit measures the computation, not just the dispatch.\n",
    "%timeit jax.block_until_ready(run_simulation_positions(0.25, 0.8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e012ce8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "id": "4e012ce8"
   },
   "outputs": [],
   "source": [
    "plot_fields(paint_snapshots(initial_conditions, lpt_particles, ode_particles))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec2-md",
   "metadata": {},
   "source": [
    "## 2. `odeint` on displacements\n",
    "\n",
    "Same simulation, but `lpt` and `make_ode_fn` are called without an explicit particle grid, and the ODE state is `[dx, p]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b71824ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def run_simulation_displacements(omega_c, sigma8):\n",
    "    \"\"\"PM simulation that evolves displacements with `odeint`, with no explicit particle grid.\n",
    "\n",
    "    Returns the linear initial field, the LPT displacement at A_INIT, and the\n",
    "    displacement at each scale factor in `snapshots`.\n",
    "    \"\"\"\n",
    "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
    "    initial_conditions = linear_initial_conditions(cosmo)\n",
    "\n",
    "    # `a` is passed by keyword: a positional third argument could be taken as particle positions.\n",
    "    dx, p, _ = lpt(cosmo, initial_conditions, a=A_INIT)\n",
    "\n",
    "    # res[0] holds the displacement at every snapshot.\n",
    "    res = odeint(make_ode_fn(mesh_shape), [dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
    "\n",
    "    return initial_conditions, dx, res[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c9fd56",
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_conditions, lpt_displacements, ode_displacements = run_simulation_displacements(0.25, 0.8)\n",
    "# Block on the outputs so %timeit measures the computation, not the asynchronous dispatch.\n",
    "%timeit jax.block_until_ready(run_simulation_displacements(0.25, 0.8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b5e684",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_fields(paint_snapshots(initial_conditions, lpt_displacements, ode_displacements))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sec3-md",
   "metadata": {},
   "source": [
    "## 3. `diffrax` on displacements\n",
    "\n",
    "This run uses the same initial state as section 2.\n",
    "\n",
    "- **Solver:** `diffrax`'s `LeapfrogMidpoint`, with a constant step `dt0 = 0.01` in scale factor.\n",
    "- **State:** `diffrax` receives a single array, so `[dx, p]` is stacked along a new leading axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43504a1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def run_simulation_diffrax(omega_c, sigma8):\n",
    "    \"\"\"PM simulation that evolves displacements with diffrax (LeapfrogMidpoint, constant step).\n",
    "\n",
    "    Returns, in order:\n",
    "    - the linear initial field;\n",
    "    - the LPT displacement at A_INIT;\n",
    "    - the states saved at `snapshots`: one stacked [dx, p] per snapshot, so\n",
    "      index 0 along the second axis is the displacement;\n",
    "    - the diffrax solver statistics.\n",
    "    \"\"\"\n",
    "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
    "    initial_conditions = linear_initial_conditions(cosmo)\n",
    "\n",
    "    dx, p, _ = lpt(cosmo, initial_conditions, a=A_INIT)\n",
    "\n",
    "    # The vector field receives the stacked state; its [d(dx), dp] output is re-stacked to match.\n",
    "    ode_fn = make_ode_fn(mesh_shape)\n",
    "    term = ODETerm(lambda t, state, args: jnp.stack(ode_fn(state, t, args), axis=0))\n",
    "\n",
    "    # Integrate exactly over the saved range so SaveAt(ts=snapshots) stays inside [t0, t1].\n",
    "    res = diffeqsolve(term,\n",
    "                      LeapfrogMidpoint(),\n",
    "                      t0=snapshots[0],\n",
    "                      t1=snapshots[-1],\n",
    "                      dt0=0.01,\n",
    "                      y0=jnp.stack([dx, p], axis=0),\n",
    "                      args=cosmo,\n",
    "                      saveat=SaveAt(ts=snapshots),\n",
    "                      stepsize_controller=ConstantStepSize())\n",
    "\n",
    "    return initial_conditions, dx, res.ys, res.stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19949ff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation_diffrax(0.25, 0.8)\n",
    "# Block on the outputs so %timeit measures the computation, not the asynchronous dispatch.\n",
    "%timeit jax.block_until_ready(run_simulation_diffrax(0.25, 0.8))\n",
    "print(f\"Solver Stats : {solver_stats}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a26e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ode_solutions[:, 0] selects the displacement from each saved [dx, p] state.\n",
    "plot_fields(paint_snapshots(initial_conditions, lpt_displacements, ode_solutions[:, 0]))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "name": "Introduction.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}