{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
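"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/Introduction.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"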
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9Jy5BL1XiK1s",
"metadata": {
"id": "9Jy5BL1XiK1s"
},
"outputs": [],
"source": [
"# Install JaxPM from the ASKabalan/jaxdecomp_proto branch; %pip (unlike !pip) installs into the running kernel's environment\n",
"%pip install --quiet git+https://github.com/DifferentiableUniverseInitiative/JaxPM.git@ASKabalan/jaxdecomp_proto"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f42bbe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c5f42bbe",
"outputId": "a7841b28-5f20-4856-bd1d-f8a3572095b5"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"# The deprecated %pylab magic used to inject these names; import them explicitly\n",
"# so any cell that still calls figure/subplot/imshow unqualified keeps working.\n",
"from matplotlib.pyplot import figure, imshow, subplot\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jax.experimental.ode import odeint\n",
"\n",
"from jaxpm.painting import cic_paint\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "281b4d3b",
"metadata": {
"id": "281b4d3b"
},
"outputs": [],
"source": [
"mesh_shape = [256, 256, 256]\n",
"box_size = [256., 256., 256.]\n",
"# Times at which particle positions are saved; the first entry is also where LPT initializes the particles.\n",
"snapshots = jnp.linspace(0.1, 1., 2)\n",
"\n",
"@jax.jit\n",
"def run_simulation(omega_c, sigma8):\n",
"    \"\"\"Run a particle-mesh simulation for the given cosmological parameters.\n",
"\n",
"    Args:\n",
"        omega_c: value passed as `Omega_c` to `jc.Planck15`.\n",
"        sigma8: value passed as `sigma8` to `jc.Planck15`.\n",
"\n",
"    Returns:\n",
"        The position component of the ODE solution, with one entry per time in\n",
"        `snapshots` along the first axis.\n",
"    \"\"\"\n",
"    # Build the cosmology once and reuse it for P(k), LPT and the ODE\n",
"    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
"\n",
"    # Tabulate the linear matter power spectrum and interpolate it on demand\n",
"    k = jnp.logspace(-4, 1, 128)\n",
"    pk = jc.power.linear_matter_power(cosmo, k)\n",
"    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk).reshape(x.shape)\n",
"\n",
"    # Initial conditions from a fixed PRNG key, so every call uses the same random field\n",
"    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=jax.random.PRNGKey(0))\n",
"\n",
"    # One particle per grid cell; indexing='ij' keeps the flattened particle order\n",
"    # identical to the C-order of the mesh (the default 'xy' swaps the first two axes)\n",
"    particles = jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape], indexing='ij'), axis=-1).reshape([-1, 3])\n",
"\n",
"    # Initial displacement and p from LPT at the first snapshot time (third output unused)\n",
"    dx, p, _ = lpt(cosmo, initial_conditions, particles, snapshots[0])\n",
"\n",
"    # Evolve the state [positions, p] forward, saving it at every snapshot time\n",
"    res = odeint(make_ode_fn(mesh_shape), [particles + dx, p], snapshots, cosmo, rtol=1e-5, atol=1e-5)\n",
"\n",
"    # odeint returns the same [positions, p] structure as its input; keep the positions\n",
"    return res[0]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "826be667",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "826be667",
"outputId": "dc43b5c4-a004-41bf-f2c8-128c17cf4de1"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"particles are Tracedwith\n",
"pm_forces particles are Tracedwith\n",
"shape of displacement: (256, 256, 256)\n",
"pm_forces particles are Tracedwith\n"
]
}
],
"source": [
"%%time\n",
"# The first call includes JIT compilation; re-run the cell to time only the compiled simulation.\n",
"# block_until_ready() waits for JAX's asynchronous dispatch so the reported time is real.\n",
"res = run_simulation(0.25, 0.8).block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e012ce8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 323
},
"id": "4e012ce8",
"outputId": "75390318-8072-481f-ffb9-ec09cd71cb1d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of grid_mesh: (256, 256, 256)\n"
]
}
],
"source": [
"# CIC-paint the particles of each snapshot onto the mesh and project along the first mesh axis\n",
"n_snapshots = len(snapshots)\n",
"fig, axes = plt.subplots(1, n_snapshots, figsize=(5 * n_snapshots, 5), squeeze=False)\n",
"for ax, a, positions in zip(axes[0], snapshots, res):\n",
"    density = cic_paint(jnp.zeros(mesh_shape), positions)\n",
"    ax.imshow(density.sum(axis=0), cmap='magma')\n",
"    ax.set_title(f'Projected density, a = {float(a):.2f}')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e050871",
"metadata": {
"id": "4e050871"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Introduction.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}