{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "view-in-github",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "intro-summary",
   "metadata": {
    "id": "intro-summary"
   },
   "source": [
    "# Introduction to JaxPM\n",
    "\n",
    "This notebook runs a particle-mesh (PM) N-body simulation with [JaxPM](https://github.com/DifferentiableUniverseInitiative/JaxPM) on a $256^3$ mesh and plots the projected matter density:\n",
    "\n",
    "1. build a Planck15 cosmology with adjustable $\\Omega_c$ and $\\sigma_8$ and tabulate its linear matter power spectrum with `jax_cosmo`;\n",
    "2. draw initial conditions from that spectrum with `linear_field`;\n",
    "3. place one particle per mesh cell and compute the initial particle state with `lpt` at the initial scale factor;\n",
    "4. integrate the particles with `odeint` up to the final scale factor, keeping their positions at each snapshot;\n",
    "5. paint the particles back onto the mesh with `cic_paint` and show the density summed along one axis.\n",
    "\n",
    "**Notes**\n",
    "- The install cell tracks the `ASKabalan/jaxdecomp_proto` development branch of JaxPM, not a pinned release, so results may change as that branch evolves.\n",
    "- A GPU runtime is recommended for a $256^3$ mesh."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9Jy5BL1XiK1s",
   "metadata": {
    "id": "9Jy5BL1XiK1s"
   },
   "outputs": [],
   "source": [
    "# Development branch of JaxPM (not a pinned release).\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5f42bbe",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "c5f42bbe",
    "outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import os\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import jax_cosmo as jc\n",
    "import matplotlib.pyplot as plt\n",
    "from jax.experimental.ode import odeint\n",
    "\n",
    "from jaxpm.painting import cic_paint\n",
    "from jaxpm.pm import linear_field, lpt, make_ode_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sim-setup-md",
   "metadata": {
    "id": "sim-setup-md"
   },
   "source": [
    "## Simulation setup\n",
    "\n",
    "All tunable parameters live in the next cell. `odeint` treats the first snapshot as the time of the initial state, so the LPT initial conditions must be computed at that same scale factor (`A_INIT`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sim-config",
   "metadata": {
    "id": "sim-config"
   },
   "outputs": [],
   "source": [
    "mesh_shape = [256, 256, 256]   # PM mesh cells per dimension (one particle per cell)\n",
    "box_size = [256., 256., 256.]  # box side lengths; assumed h^-1 Mpc to match jax_cosmo's k units -- TODO confirm\n",
    "\n",
    "A_INIT = 0.1     # scale factor of the LPT initial conditions = first snapshot\n",
    "A_FINAL = 1.0    # scale factor of the last snapshot\n",
    "N_SNAPSHOTS = 2  # number of scale factors at which positions are saved\n",
    "SEED = 0         # PRNG seed for the initial conditions\n",
    "\n",
    "snapshots = jnp.linspace(A_INIT, A_FINAL, N_SNAPSHOTS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "281b4d3b",
   "metadata": {
    "id": "281b4d3b"
   },
   "outputs": [],
   "source": [
    "# NOTE: jax.jit bakes the config globals (mesh_shape, box_size, snapshots, A_INIT, SEED)\n",
    "# into the compiled function when it is first traced; re-run this cell after changing them.\n",
    "@jax.jit\n",
    "def run_simulation(omega_c, sigma8):\n",
    "    \"\"\"Run a PM simulation and return particle positions at each snapshot.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    omega_c : float\n",
    "        Passed to ``jc.Planck15`` as ``Omega_c``.\n",
    "    sigma8 : float\n",
    "        Passed to ``jc.Planck15`` as ``sigma8``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Array\n",
    "        Particle positions, in mesh-cell units, stacked along a leading axis\n",
    "        with one entry per scale factor in ``snapshots``.\n",
    "    \"\"\"\n",
    "    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
    "\n",
    "    # Tabulate the linear matter power spectrum and wrap it in an interpolating function\n",
    "    k = jnp.logspace(-4, 1, 128)\n",
    "    pk = jc.power.linear_matter_power(cosmo, k)\n",
    "    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
    "\n",
    "    # Initial conditions drawn from the tabulated spectrum\n",
    "    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(SEED))\n",
    "\n",
    "    # One particle per mesh cell, placed on the regular grid of cell indices\n",
    "    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape]), axis=-1).reshape([-1, 3])\n",
    "\n",
    "    # Initial state at A_INIT, which must equal snapshots[0] (the third lpt output is unused here)\n",
    "    dx, p, _ = lpt(cosmo, initial_conditions, particles, A_INIT)\n",
    "\n",
    "    # Evolve the state [positions, p] through every requested snapshot\n",
    "    res = odeint(make_ode_fn(mesh_shape), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
    "\n",
    "    # res mirrors the state list; keep only the positions at every snapshot\n",
    "    return res[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "run-sim-md",
   "metadata": {
    "id": "run-sim-md"
   },
   "source": [
    "## Run the simulation\n",
    "\n",
    "The first call includes JIT compilation. To benchmark steady-state runtime, time a later call and wait for the asynchronously dispatched result, e.g. `%timeit run_simulation(0.25, 0.8).block_until_ready()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826be667",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "826be667",
    "outputId": "dc43b5c4-addf-4b9d-bb0e-c9c1e3e0d2a1"
   },
   "outputs": [],
   "source": [
    "positions = run_simulation(0.25, 0.8)\n",
    "\n",
    "n_particles = mesh_shape[0] * mesh_shape[1] * mesh_shape[2]\n",
    "assert positions.shape == (N_SNAPSHOTS, n_particles, 3), positions.shape\n",
    "positions.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "plot-density-md",
   "metadata": {
    "id": "plot-density-md"
   },
   "source": [
    "## Projected density\n",
    "\n",
    "Particles are painted back onto the mesh with `cic_paint`, and the resulting density is summed along the first mesh axis, one panel per snapshot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e012ce8",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "id": "4e012ce8",
    "outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
   },
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, len(snapshots), figsize=(5 * len(snapshots), 5), squeeze=False)\n",
    "for ax, a, pos in zip(axes[0], snapshots, positions):\n",
    "    density = cic_paint(jnp.zeros(mesh_shape), pos)\n",
    "    ax.imshow(density.sum(axis=0), cmap='magma')\n",
    "    ax.set_title(f'a = {float(a):.2f}')\n",
    "fig.suptitle('Projected density (sum along mesh axis 0)')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "include_colab_link": true,
   "name": "Introduction.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}