JaxPM/dev/HamiltonianGNN.ipynb

899 lines
487 KiB
Text
Raw Normal View History

2022-04-28 00:20:44 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline\n",
"import camels_library as cl\n",
"import readgadget\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading CAMELS data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 34/34 [01:48<00:00, 3.19s/it]\n"
]
}
],
"source": [
"scales = []\n",
"poss = []\n",
"vels = []\n",
"\n",
"snapshot='/data/CAMELS/Sims/IllustrisTNG_DM/CV_0/snap_000.hdf5'\n",
"\n",
"header = readgadget.header(snapshot)\n",
"BoxSize = header.boxsize/1e3 #Mpc/h\n",
"Nall = header.nall #Total number of particles\n",
"Masses = header.massarr*1e10 #Masses of the particles in Msun/h\n",
"Omega_m = header.omega_m #value of Omega_m\n",
"Omega_l = header.omega_l #value of Omega_l\n",
"h = header.hubble #value of h\n",
"redshift = header.redshift #redshift of the snapshot\n",
"Hubble = 100.0*np.sqrt(Omega_m*(1.0+redshift)**3+Omega_l)#Value of H(z) in km/s/(Mpc/h)\n",
"\n",
"# Loading all the intermediate snapshots\n",
"for i in tqdm(range(34)):\n",
" snapshot='/data/CAMELS/Sims/IllustrisTNG_DM/CV_0/snap_%03d.hdf5'%i\n",
" \n",
" header = readgadget.header(snapshot)\n",
" \n",
" redshift = header.redshift #redshift of the snapshot\n",
" \n",
" ptype = [1] #dark matter is particle type 1\n",
" ids = np.argsort(readgadget.read_block(snapshot, \"ID \", ptype)-1) #IDs starting from 0\n",
" pos = readgadget.read_block(snapshot, \"POS \", ptype)[ids]/1e3 #positions in Mpc/h\n",
" vel = readgadget.read_block(snapshot, \"VEL \", ptype)[ids] #peculiar velocities in km/s\n",
"\n",
" # Reordering data for simple reshaping\n",
" pos = pos.reshape(4,4,4,64,64,64,3).transpose(0,3,1,4,2,5,6).reshape(-1,3)\n",
" vel = vel.reshape(4,4,4,64,64,64,3).transpose(0,3,1,4,2,5,6).reshape(-1,3)\n",
" \n",
" pos = (pos / BoxSize * 32).reshape([256,256,256,3])[2::8,2::8,2::8,:].reshape([-1,3])\n",
" vel = (vel / 100 * (1./(1+redshift)) / BoxSize*32).reshape([256,256,256,3])[2::8,2::8,2::8,:].reshape([-1,3])\n",
" \n",
" scales.append((1./(1+redshift)))\n",
" poss.append(pos)\n",
" vels.append(vel)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running the PM simulation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"import haiku as hk\n",
"\n",
"from jax.experimental.ode import odeint\n",
"\n",
"from jaxpm.painting import cic_paint, cic_read, compensate_cic\n",
"from jaxpm.pm import linear_field, lpt, make_ode_fn, pm_forces\n",
"from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel\n",
"from jaxpm.utils import power_spectrum\n",
"\n",
"rng_seq = hk.PRNGSequence(1)\n",
"\n",
"mesh_shape= [32, 32, 32]\n",
"box_size = [25., 25., 25.]\n",
"cosmo = jc.Planck15(Omega_c= 0.3 - 0.049, Omega_b=0.049, n_s=0.9624, h=0.671 , sigma8=0.8)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Run the simulation with Particle-Mesh scheme\n",
"resi = odeint(make_ode_fn(mesh_shape), [poss[0], vels[0]], jnp.array(scales), cosmo, rtol=1e-5, atol=1e-5)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkwAAAJBCAYAAAC0+uodAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eZxddX3///7cO3f2JbNkGSaZLJMNQiAESEAQCUpEUMEFW7CCLZZvrRWtbX+gX63a2hbbaq1atRQX/FZQQQQEhCCGJSwhIQQSsk9IJstMklkz+9yZe35/kJzzer/u3HNmywXJ+/l48OC887n3LJ/zPp975vN+fd5v53meGIZhGIZhGJmJvdEnYBiGYRiG8WbHXpgMwzAMwzAisBcmwzAMwzCMCOyFyTAMwzAMIwJ7YTIMwzAMw4jAXpgMwzAMwzAiGNcLk3PuMufcdufcLufcLRN1UsbJhfmRMV7Mh4yJwPzICMONNQ+Tcy4uIjtE5FIR2S8i60TkGs/ztkzc6RlvdcyPjPFiPmRMBOZHRhQ54/juMhHZ5XnebhER59zPReRKEcnoXCUlRV5l5STfTqUK/G3nkuqznqdPLZUqovYEfHdAtbGdn/+avz0wUE37yVP24GAp7Svlb8divRJGKlUI2/pFNEZzeZ6n/yEe74LvFqi2gYF++mwu7TsO+9XHLSraRvuqAWtIteG1iogMDem+wHsSi/X4262tTdLd3e5kbIzKj/LzJ3klJcE9xOtNpbQPORc+gTo42Odvx+N5IZ8UycsL+iKZ7FFtnqf7cXBQ3y88j1gs/JHDfaVSg7SfONmO7OA4nqfvZR+dcy75UNhxCgrKlT0wEPgqHycnR/vu4KB+ZvB+od+KiLS27mr2PG9yxhPLzKjHotzcIi8/H68rOC9viMYi+m4/nfdk8KN26teU6HtUmBuMY+h/IiL9Ke1HRXT/h+B+8375nBJwX0roepL0XAyQjdcbp6ufNKTHVv4u0kO+HsstUXZfshvOl8bLRKGy+8mP8OoH4RwG+zoklezJylgkIpKTKPASMDbggUup34foniXpR6ES/KGD/Ii/6+UWB/uhvknQM1lAfuTBvqJ8Ow73pZjG10E6pwEaN9R+6Eil5EeDNJbhObIf9eXqd4FYMrj+AhqLuxL6s3F65pAUnENvf6cMDPal+dF4XphqRGQf2PtFZHnYFyorJ8mXvvQp3+7pWeRv5+YeVJ9NJqcou6vrbGXji09eXoNqy8vbp+z586/zt/ft07OsfX1zlN3WtlLZ8XjwQ1NY+KqE0d19JuxXO1dBge77VCpf2SUla/1t7BcRkf37dyu7rKxW2bnw8PAP2Lnnvk3Ze/d+zd+OxTpVW06OttvbL1V2Mhn8lhUVvexv/8d//JmMg1H5UUlJtXzgAz/1bfzR6e4+rD7LL5ZCD3hr605/e9KkWfqT9EMwe/Yl/nZT00bV1tfXoez29j3KzoUHvKCgQsLo7z/qb/P14EubiEhOTn5Gm1/qtjW9rOza0hpl4/WiP4mInHbah5Xd0PCMvz00pF8Op0w5XdncV0PwI5KfP0m13XXX+/bK2Bj1WJSfXy7LlgVjEb6cJjsb1WeT9CNSTz/8n2je6m/fR37Eg/3Z08/zt48c2ara6gf087ec7j/+iPbQj1N9fpmyq+H+r+jUY2sjvYw00I8KXm8Z/bBdRb7N30U2FOl335JZFyv71QPrgvOlH/3SmmXK3n5ok7IT8APcDH7fvPEnGc9nBIzajxJ5pTJr8UeHPa+VR/erz3bEEspupB/+61t2+NsP0RjP3x2acb6/feDwZtVWTc/+4p4WZeP95Zfnehpj8P5f2NWk2ppp/AnzhTJ62bqU+qaF9pWEsZr9aBv5RuHh4Hd5cV+banuWPlvSslMy0RMP+vj5rfcO+5kTLvp2zt3onFvvnFvf2dkd/QXDINCH+vra3+jTMf5AQT9KJm0sMsYG+tFgMjzqYLy1GM8M0wERmQ
H29GP/pvA87zYRuU1EZM6cSq+wMJjdnDMnmO2Ix/Vftf39+q/2vXv/Udmtre/zt53TU46TJj1Gn32vvz00pP86PHRIz8CUle1Sdm/vXH97YOAU1cYzTqlU8BdgWZme9cL9iIhMm3a7snHfra1HVRvPSnBYB2eV+Pybmm5UdmnpGn8bw4DcJiJy6qk3K3vLlm/62yUlz/vbiYT+C2aURPoR+tDUqWd4RUXB7OO8eU/72+gTIiI9PfrvgYMH1yu7tHS6v82zKngMEZHe3lZ/m1/a2tv1xEhenvYxnDXiMAzPbA3BX3QlJdrfMBQmIlJYWJXxu62t2g9q6K80DkFi6IxnMHlf+TCbwbNE8+frEPBpp+kflI0bF/vbUSHTUTDqsai8fLaH/VsDf4m2ttar7/X0HFF28xEdoVkFflRFzyr7FfpRkmZ+ltMsC89O1YEfFVKYhWcrcJZhNflRFfkg/mUtokM6S7v1tW+KmCFF1tIz9LaeZmXjDMWMEi2VmDfvcmWfc84Hlb1u3T3+dhX04/rx+dSo/aisbIaHs7U1NcGE1Jaml9T3eNa3vFk/K6tKgv1soGe7kULdi2G2Z2WHjqgkKPyVpJn1WhhHykg+wLOP6CtriqdlbBsOPI+FNGZuzi/P+FkR7fsbCir1Z6kf0X/vrNC/s1cs+oiyp0zVz8lTT/7M3+6A2VMOgR5nPN61TkTmOedmO+dyReSPReSBcezPODkxPzLGi/mQMRGYHxmhjHmGyfO8QefcX4nIoyISF5EfeZ4XLvIxDML8yBgv5kPGRGB+ZEQxnpCceJ73sIg8PPJvxGRoKJii7u4Opmz7+uarT+bmauElh48qK3/tb5eWPqPacnK08AtDZwcPfka1FRToSbbWVj3l7Fw7bOtpYxabFxQEoYuioldUWzKppyAPH75O2WVlT/jbCZoWHSLhJYcyJk0KpnYnT76TjqtDMdg3kyY9otpmztQCuY4OPYWOq+T6+4OwDYYix8Jo/MjzhlSIa//+K/1tFltX0PRsjISTGA7D8JxIeh+XlQVCyYMHeRXmJGV3dWnfxZVjvF+28d5zOCdFq6i6SITJIbpM+xVJX6FWWTl/2G2R9DAihoC5j+NxLS5vbbtM2QMD2/1tXo03HkY7Fg0NJZWo/giE2VpAfCsiUl6uF4Zw2GIWhKmKKWzBizCmTVvibx8+pMeItLALCViV6JvCdSwKRruQVmsxLBJuAL8rpFVHWiiRvpoLRcMXUWiX/QjhEDgveHjtNS1s7oQw3DY4/z7q79EyWj/qG0rKNjhXdzAQsu+lUNls8iPuu20wjkwlP5pC14Uh5A4KEfMqudp+vZgAj8vhOj4nDNmxn7BgnIXdGHLlxQPsv2yjD+YUT1Vt/ExhaDBZOkO18aKTTZv0QgsMi1fDfrZn8CPL9G0YhmEYhhGBvTAZhmEYhmFEYC9MhmEYhmEYEYxLwzRaPC9HBgcDLU9bW7AMvG6uTgUQc/rU+vtfVHZ39xn+dm+v1ly0tHxA2ZgkEhNEiqQvpS8t1XF3nVFcx2i7IFGliIgHGboPH36/aqus/L2yOWFmS0uQGBGzaL9+XP1ey8tTURcWj+t4Nfa3iMju3f/pb8+e/beq7dAhrdE6evSvlI3JQlEbxVnYTyxOpVVoawuSek6deob65MyZX1b20aMfUzbGwjlVwBHSBWASTNYKsT6INU2Y3DFsmbmI1inxObE+poeWaaOmic+pn3QMnJ0c+5S1UagZE9F9zikvDh6cpezBQZ1wEPfF55hNnNN+1AnJKllPs2CBfpb5niF8zw52aJ3ja5BkcAot92dWUaoA1GqwjqqOkl6GJSfclKeXaXNiwzpIxMqJDKsHdP6qWkp1gfqZBvKjevrsJ0Ertp76YseOB5XNOs7dOAZOXHqKUeNEXzPqtKZRMtG5c9+t7C10z6rg2X+V/IiTjzoYnzjJJWvWlua0ZGzvKNDpRVinhLDPsc2pA9BXttGYyLqqai/zcTd1HVI2a6
UwmeqfU19U7X9e2aylKkvLdR6OzTAZhmEYhmFEYC9MhmEYhmEYEdgLk2EYhmEYRgRZ1TANDRVJZ2dQfFIVoNXVF1R
"text/plain": [
"<Figure size 720x720 with 16 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# High res simulation\n",
"figure(figsize=[10,10])\n",
"for i in range(16):\n",
" subplot(4,4,i+1)\n",
" imshow(cic_paint(jnp.zeros(mesh_shape), poss[::2][i]).sum(axis=0), cmap='gist_stern', vmin=0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkwAAAJBCAYAAAC0+uodAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAD7tUlEQVR4nOy9e5xddXnv/3z3nj33+0wukyvJBBIugRAgo0KVqFCEWqgFT6EqPeqh2tbLrz3nJ/XY36mtPcVWrcfa2kPFilWwYhVQUIMa1KAGQrgk5D4JTC4zk8xM5pK57pn9/f2RyVqf57NnrzW37EDyvF8vXqwn373X5bue9d1rvs/n+zzOey+GYRiGYRhGbhJn+gQMwzAMwzBe7dgLk2EYhmEYRgz2wmQYhmEYhhGDvTAZhmEYhmHEYC9MhmEYhmEYMdgLk2EYhmEYRgwzemFyzt3gnNvtnNvnnLt7tk7KOLcwPzJmivmQMRuYHxlRuOnmYXLOJUVkj4hcJyKHROQZEbnde79j9k7PONsxPzJmivmQMRuYHxlxFMzgu+tEZJ/3fr+IiHPumyJys4jkdK6KijJfV1cd2JlMSbDtXFp91nt9aplMGbWn4Lsjqo3t4uIDwfbISAPtp0jZo6OVtK9MsJ1IDEoUmUwpbOsX0QTN5Xmv/yGZPAHfLVFtIyPD9NlC2ncS9quPW1a2i/a1EKwx1YbXKiIyNqb7Au9JIjEQbHd1tUl/f7eT6TElPyourvYVFQvgnDKwPTbRV3IyOhr2azKZUm3O6ftTWFgRbKfTA6qNjzs2pv1PJOyaRCL6kcN9ZTKjdE5Jsh3Z4TlnMvqcRka176bIh6KOU1xcrex0uh/OV/tbQQE/T0PKxs+j34qIdHXt6/Dez8l5YrmZ8lhUWFjmi4tr8MyCrTHqd/6Tckx0v9eOhX7UTf3q6bOlheE4lk7rezLi9fNXFnEeafLPUbKT8OmysXTkZ9nG4yTp6qtH9Vg0FOHPA9SWLNRj+CBcfwFde0FKj4H99Ezh5zPQx6PDvTKWHsjLWCQikkyV+FRRVWhDf5WO6fuXoedVX7FINVxjL/kRfzaTCn9rxugZK6B7VkJ+hIwK+4I+R2wty8T4keTudvajCvJJ9mekn/xoCK5dRCQJ119IfjRIfpSh46bw9wOufWC4T0ZGh7IuaCYvTAtF5CDYh0SkKeoLdXXV8hd/8cfhSQ1cHGwXFh5Rn02n5yr7xIkrlI0vPkVFLaqtqOigsi+44D3B9sGDepZ1aGi5so8fv17ZyWT441ha+pJE0d9/GexX35iSEnpYMsXKrqjYHGxjv4iIHDq0X9lVVUuUXVhYHmx7cpirrnqDsl955VPBdiLRp9oKCrTd3X2dstPp8LesrOyFYPsf/uG9MgOm5EcVFQvkHe/4dzin8P4MDfWoz/IPMr+AHD8e9mtl5SLVxi+ly5e/Ndg+cuRZ1TY01K3s3l7tf7ivkpJa1cYvZsPDvcH2wECHasP7PNE5pmAgGRk5odr2t7+o7Aa6XqSoSL8oX3jhO5R98OAvg232t5qaRmV3dOxUNt4Dvp4HH/ztV3KeVDRTHouKi2ukqelDgT0GA2nPwDH1WR7Me6jf7+jcG2w/XH2e/i754NqF64LtY8f07/B+ehFv6j+q9wU/SK308tFRoMeTKvjx5f3wZ9nG41TRj+Qt3S8rexe9TCNbS+v1OS16nbK3Hd0ebNfTj/6cuZco+9fd2jXw8wPwx07b8/fnPJ9JMGU/ShVVyaLL3h3Y2O9r+7UfDdAfZQP0hwn27YaqxZGfHW64PNju6tyj2rgvVw905jr9WF8ohT/gpupHSBW9aF97olXZbQX6xQbZXD5P2XvmXarsCnj+lsAfcyIi2+auVvZAf7uyG0bCz+OzumnXwxOey2kXfTvn7nLObXHObenr64//gmEQ6ENDQ8fP9OkYr1HQj9JpG4uM6YF+NEYvucbZzUxmmA6LCL4GLxr/N4X3/l4RuV
dEZPnyOl9aGv5VtXx5ONuRTOq/NoeH9WzBK6/8tbK7ut4ebDunZw6qq5+gz/5WsD02VqHa2tv1DExV1T5lDw6uCLZHRhaoNp5xymTCvzyrqvSsF+5HRGT+/C8rG/fd1dWr2nhWgsM6+Fc+n39b213KrqzcFGxjGJDbREQuvPBjyt6x43PBdkXFr4PtVCr3XzCTINaP0Ifmz79MheTmzAmvIZ2+UO24v1//RdfaulXZOKvEsx3l5fOVfeJEW7A9OKivt6dHzygVF1cpe3CwK9jmcF01zUhgO58DzxqVlelZWAx/4fmKiMylz6ZoWhtn6ngGs5tmFTBEV0qzCEuX6mdx9Wr6i2/blcH2VEOoEUx5LKqpWe7xOrGv+R719h5S9vP0Fz3OBsyne8+zdTgbOUgziE10f9MU4miAsCr/Jc0zWw1wPzfTvW+gH/msWQWYBVxFs23baCzimY8UhF54hukNw3pcw+M2kq8vXfpGZV98sW7fBTMA6OvHY0LeMUzZj6qqlvgLqpYGbfhstNE96qcZp6rjOnKwCfqAZ+6ayY+uBF9ZS37EM0wMtrMfbaxYkPOzfD/5OOyvCB9nd8TMJO9rmwqdixTQ89kDMoCHy3RE/y3nvUnZtbX6d3j79m8G28dgnMsVIpzJDNMzInK+c26Zc65QRH5PRB6dwf6McxPzI2OmmA8Zs4H5kRHJtF/Hvfejzrk/EZEfiUhSRL7ivY8W+RgGYX5kzBTzIWM2MD8y4pjR/KX3/nEReXzy30jI2FgY+ujvD6eKh4YuUJ8sLNSiMA4f1dV9N9iurHxKtRUUaJ0Lhs6OHPmIaisp0ZNsXV0syu2Gbb3CjsXmJSVhOKysTIts02k9rXj06HuUXVX1ZLDN4RIOEbBQuLo6XAk3Z84DdFw9RYl9U139Q9W2dOk6Zff06KleXCU3PByGMzAUOR2m4keZzKj0g/iwsDAUGff16dlzDncVsDAWQjIs+ubPYhiUhc6lpXXK5hBOQQGuBo2e1MV7z37AcNittDS813wcvh4Gp6p52ppXc6Fwm/utpOQxZXd2/o6yh4b0czFbTHUsymTS0td3BOwwPMghVw5RLqf7UgBhKg6Tcvgc7wOvImsl4etq0uuh2Dzled2UBsNdpbQyiI+7hEKBzbAitIrGnla6dhYyY+joDWVarJu93jCEZQe8urKlRcsFenrCsXc7CIoHZxjmnaofDWbS8jyMO6thxeRWElu/rnqpsjvA/0R0GG4RCZ3r6Z7h2MCLEJhGWgwT5UccVsOQa1YIjsYYbm8pqsjZluVHdH0tIJHYSmG2FC3eWQ/3v7FCS25QDiGiQ7kiIt0QFm2Ac2zO8XxZpm/DMAzDMIwY7IXJMAzDMAwjBnthMgzDMAzDiGFGGqap4n2BjI6GWp7jx8PUAI0rdCqAhNOnNjyskwX294fJqwYHtf6JdROYJBITRIpkL6WvrNR6BZ1RXGsBTkCiShERDxm6jx79bdVWV/dTZXPCzM7ONwfbmEX75HH1ey1nmkZdWDKpk09if4uI7N//f4LtZcv+u2prb9card7eP1E2JgtFbRRnYT+9OKUDOX68OdieS8nuFi/+rLJ7e39X2ajFGSC9QVfXXmVj2gFeSs/6IMwKLiIyBroGTl/AMXbU0gwP63tZTroGTmyJGi4+p2Fa0s2gj8WlM+jqCrV6nAX8hReW0XGfVjb6LifePJNgfi/W0yxYcJWy+ymBHya95CSm+0in0pMIx5M11M9jpDXaQMkpMckeJ67kVAEp8KMeyr7OWqlVdM6oneLl7Y3kR6x/Qk1MC+23eUT78/tBP7KDtF/79/9Y2TzmbadEiGeKhPcqDQM+K6tJp9PQsFbZ/OzPB73YS9R3TSd0wkXUjm0iP2LdGWucsD0r4SklKsXElS00dvF32Y/QV+L8KEtLBeMR75d9HRN+fqRyoWqrgiS7Itl9UwXPCR4zl9rOZpgMwzAMwzBisBcmwzAMwzCMGOyFyTAMwzAMI4a8ap
jGxsqkry8swKgK0OqKHiovkUh2zqPm5n/K2dbff5GyM5lQU8L5nOrrdd6ikhJd9oDzJanvQi4oEZH58++FY2qdAGu
"text/plain": [
"<Figure size 720x720 with 16 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# JaxPM without correction\n",
"figure(figsize=[10,10])\n",
"for i in range(16):\n",
" subplot(4,4,i+1)\n",
" imshow(cic_paint(jnp.zeros(mesh_shape), resi[0][::2][i]).sum(axis=0), cmap='gist_stern',\n",
" vmax=cic_paint(jnp.zeros(mesh_shape), poss[::2][i]).sum(axis=0).max(),vmin=0)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/local/home/flanusse/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py:488: ComplexWarning: Casting complex values to real discards the imaginary part\n",
" return _convert_element_type(operand, new_dtype, weak_type=False)\n"
]
}
],
"source": [
"k, pk_ref = power_spectrum(\n",
" compensate_cic(cic_paint(jnp.zeros(mesh_shape), poss[-1])),\n",
" boxsize=np.array([25.] * 3),\n",
" kmin=np.pi / 25.,\n",
" dk=2 * np.pi / 25.)\n",
"\n",
"k, pk_i = power_spectrum(\n",
" compensate_cic(cic_paint(jnp.zeros(mesh_shape), resi[0][-1])),\n",
" boxsize=np.array([25.] * 3),\n",
" kmin=np.pi / 25.,\n",
" dk=2 * np.pi / 25.)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, '$P(k)$')"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEQCAYAAABbfbiFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAyKUlEQVR4nO3dd3yN5//H8dd1MkUGiS2pWTPLCGrX1tpVFDVqVin9dfuWb6tTq3YVrVGtGrVK6ZfWKDETxN4EsRMSMbKv3x8nUsTIOMmdk3yej8d5JOfc933dnxNH3rnucV1Ka40QQghxP5PRBQghhMh5JByEEEKkIuEghBAiFQkHIYQQqUg4CCGESEXCQQghRCq2RhdgKYUKFdKlS5c2ugwhhLAqu3fvDtdaF3749VwTDqVLlyY4ONjoMoQQwqoopc4+6nU5rCSEECIVCQchhBCpSDgIIYRIJdeccxDCGsTHxxMWFkZMTIzRpYg8xtHREU9PT+zs7NK0voSDENkoLCwMFxcXSpcujVLK6HJEHqG1JiIigrCwMMqUKZOmbaz+sJJSqq1SamZUVJTRpQjxVDExMXh4eEgwiGyllMLDwyNdPVarDwet9Sqt9UA3N7cMbR8afpvYhEQLVyXE40kwCCOk93Nn9eGQGUlJmm0/DGf6l2+zcNsx4hKSjC5JiCynlOLtt99OeT5u3Dg+/vjjR67r7Oyc4f2Ehobi7e2d4e2FsfJ0OCiSaOVxleGJc2i8tiWTvvqA33aeIj5RQkLkXg4ODixbtozw8HCjSxE5WN4OB5MN7gNXonuvIl/hsrybMIM6q1sybuxolgaFkiAhIXIhW1tbBg4cyIQJE9K0/ltvvUXVqlVp2rQp165dAyAkJIQ6derg6+tLx44duXHjBgC7d+/Gz88PPz8/vvvuu5Q2GjZsSEhISMrz+vXrs2/fPsu9KWFxcrUSoMo0xO2N9eiT63Fb818+vDGFU6uW8MXf3fFt2Ze2/p7YmOQ4sbCsT1Yd4vDFmxZts0oJV/7btupT13vjjTfw9fXlvffee+J6t2/fpmbNmkyYMIExY8bwySefMHXqVHr16sWUKVNo1KgRo0eP5pNPPmHixIn07duXqVOn0rBhQ959992Udvr168fcuXOZOHEix48fJyYmBj8/v0y/X5F18nTP4QFKoZ5thuubgeiu8ylS0IXRsd9SaUVrPvl6LKtCLpCUJPNti9zB1dWVXr16MXny5CeuZzKZ6Nq1KwA9e/YkMDCQqKgoIiMjadSoEQC9e/dm8+bNREZGEhkZScOGDQF49dVXU9p5+eWX+eOPP4iPj2f27Nn06dMna96YsBjpOTxMKVTlNrhUfIGkQ8vxXDuGMbe+ZN+yRYxa14v6rbrS0rs4JulJiExKy1/4WWnEiBFUr16dvn37ApCYmEiNGjUAaNeuHWPGjEm1TUavtHJycqJ58+b8/vvvLF68mN27d2e8cJEtpOfwOCYTJp+XcH5rN0ntvuNZ51g+v/MxhZa054Px01h36DJaS09CWC93d3e6dOnCrFmzALCxsSEkJISQkJCUYEhKSmLJkiUA/Prrr9SvXx83NzcKFizIli1bAPj5559p1KgRBQoUoECBAgQGBgIwf/78B/bXv39/3nzzTQICAihYsGB2vU2RQRIOT2Nji6l6T5z+L4SkF8bj7RTJ17dG4rSwE+9OmMWGo1ckJITVevvtt5941VL+/PnZtWsX3t7ebNiwgdGjRwPw008/8e677+Lr60tISEjK63PmzOGNN97A398/1f+LGjVq4OrqmtJTETmbyi2/2GrWrKmzZT6H+LskBs0mftM4HOOusz6xGqs8XqPjC61p+GwhucFJPNGRI0eoXLmy0WUY4uLFizRu3JijR49iMsnfpUZ41OdPKbVba13z4XXlXyi97PJhU/cNHN8+QGKT0dR3PMXEyGHc+rk7I6YuZOvJcOlJCPGQefPmUbt2bT7//HMJBishPYfMiokice
tUErd9h23iHVYk1mNjsdfo0boxdcp6ZH89IkfLyz0HYTzpOWQnRzdsmv4H+/87QNJzw2hnF8yEa/05M7sfw75fSXDodaMrFEKIdJNwsJT8Hti2/BTbt/ZDQH+62AXy7ZW+HPxxEENnrmHPuRtGVyiEEGkm4WBpLkWxffEbbEaEYPLvQS+7DXxzsTe7Zg5l2I9/sT8s0ugKhRDiqSQcsoqbJ7YdJmMaFoStdwcG2a7mq7CebJw+gmGzN3Hoosw/IYTIuSQcspp7Wew6/4AasgP7Si0Zbrucz891Z/V37zD8p0COXrbs2DpCpEVGhuIODQ0lX758+Pv7U6VKFQYPHkxSUhKhoaEopfjoo49S1g0PD8fOzo6hQ4emez/BwcG8+eabAGzatIlt27alLOvTp0/KTXmZNXHiRO7cuWORtizp4bpeeOEFIiMjs70OCYfsUqQSdt3mwaAt5Ctfn/fsFjP69Cv8NuUDRvyynRNXoo2uUIinKleuHCEhIezfv5/Dhw+zYsUKAMqUKcPq1atT1vvtt9+oWjVjw4PUrFkzZcynh8PBkrIiHBISEp74PC0ermvNmjUUKFAgs6Wlm9WHg9VNE1rcF7uei6H/elzLVGeU3Xw+PPEKP00exf8t2MWpa7eMrlDkEbdu3aJp06ZUr14dHx8ffv/9dwCCgoLw9fUlJiaG27dvU7VqVQ4ePPjAtra2ttStW5eTJ08C5rGTKleuzL3LyRctWkSXLl0euV8fHx8iIyPRWuPh4cG8efMA6NWrF3/99RebNm2iTZs2hIaGMn36dCZMmIC/v3/KcB2bN2+mbt26lC1bNqUXobXm3XffxdvbGx8fHxYtWgSQ0tY9Q4cOZe7cuUyePJmLFy/y/PPP8/zzz6eqMSgoiLp16+Ln50etWrWIjo4mJiaGvn374uPjQ7Vq1di4cSMAc+fOpV27djRp0oSmTZumen779m1ee+01atWqRbVq1VJ+zomJibzzzjt4e3vj6+vLlClTHllX6dKlU+5iHz9+PN7e3nh7ezNx4kTA3KOrXLkyAwYMoGrVqrRo0YK7d++m+XPwOFY/8J7WehWwqmbNmgOMriVdPGti1+d3CA3E/a8xfHZhDheO/sGkgx1J8u3KsGaVKeWR3+gqRVb68wO4fMCybRbzgdZfpWlVR0dHli9fjqurK+Hh4dSpU4d27doREBBAu3bt+Oijj7h79y49e/bE29ub0NDQlG3v3LnD+vXrHxicr1u3bixcuJCiRYtiY2NDiRIluHjxYqr91qtXj61bt1KqVCnKli3Lli1b6NWrF9u3b+f7778nKCgIMP9SHDx4MM7OzrzzzjsAzJo1i0uXLhEYGMjRo0dp164dnTt3ZtmyZYSEhLBv3z7Cw8MJCAhIGR32Ud58803Gjx/Pxo0bKVSo0APL4uLi6Nq1K4sWLSIgIICbN2+SL18+Jk2ahFKKAwcOcPToUVq0aMHx48cB2LNnD/v378fd3Z25c+c+8HzkyJE0adKE2bNnExkZSa1atWjWrBnz5s0jNDSUkJAQbG1tuX79Ou7u7o+ta/fu3cyZM4edO3eitaZ27do0atSIggULcuLECRYsWMAPP/xAly5dWLp0KT179kzT5+BxrD4crF7p+tj1Xwun1lPkr0/5+spMzhxeyYQDnXHw68zQphXxcncyukqRC2mtGTlyJJs3b8ZkMnHhwgWuXLlCsWLFGD16NAEBATg6Oj4wrPepU6fw9/dHKUX79u1p3bp1Smi0atWKUaNGUbRo0ZRhvh+lQYMGbN68mVKlSvH6668zc+ZMLly4QMGCBcmf/+l/EHXo0AGTyUSVKlW4cuUKAIGBgbzyyivY2NhQtGhRGjVqRFBQEK6urun+uRw7dozixYsTEBAAkNJGYGAgw4YNA6BSpUqUKlUqJRyaN2+Ou7t7Shv3P1+3bh0rV65k3LhxAMTExHDu3Dn+/vtvBg8ejK2t+dfw/ds/SmBgIB07dkz5GXXq1IktW7bQrl07yp
Qpg7+/P2Aew+r+IM8oCYecQCko3wy7ck3h2Bq8/v6UieFTOXbwd74IeZkC1TsytOmzlCyQz+hKhSWl8S/8rDJ//ny
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"loglog(k,pk_ref, label='N-body')\n",
"loglog(k,pk_i, label='JaxPM without correction')\n",
"legend()\n",
"plt.xlabel(r\"$k$ [$h \\ \\mathrm{Mpc}^{-1}$]\")\n",
"plt.ylabel(r\"$P(k)$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implementation of a plain Hamiltonian Graph Neural Network"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# We start by computing a neighborhood graph used for particle-particle (PP) interactions\n",
"# Open question: can we build a useful graph from this neighborhood?\n",
"from sklearn.neighbors import radius_neighbors_graph\n",
"\n",
"A = radius_neighbors_graph(poss[-1], 1.5, mode='distance').tocoo()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import jraph\n",
"from frozendict import frozendict\n",
"\n",
"# Initialize graph structure\n",
"static_graph = jraph.GraphsTuple(\n",
" n_node=np.asarray([A.shape[0]]),\n",
" n_edge=np.asarray([len(A.data)]),\n",
" nodes={\n",
" 'position': poss[0],\n",
" 'momentum': vels[0],\n",
" },\n",
" senders=A.row,\n",
" receivers=A.col,\n",
" edges={'grav_potential':jnp.zeros(len(A.data))},\n",
" globals={}\n",
")\n",
"\n",
"# Tell tree_util how to navigate frozendicts.\n",
"jax.tree_util.register_pytree_node(\n",
" frozendict,\n",
" flatten_func=lambda s: (tuple(s.values()), tuple(s.keys())),\n",
" unflatten_func=lambda k, xs: frozendict(zip(k, xs)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<32768x32768 sparse matrix of type '<class 'numpy.float64'>'\n",
"\twith 8060958 stored elements in COOrdinate format>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAD4CAYAAAAZ1BptAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAATN0lEQVR4nO3df4xl5X3f8fcnEBy7iQ02G2rtsl2ibNqsaSPjCZC4yg8TwYIrL1UdC5TUa2vllWJw80ttoJVKZScRVts4IbKJNmbDYqVeExKVVQwlK4xltemuWUIKBkqY4h/sBhviBdyWxs7a3/5xH+Lr8TyzM3Nn770z835JV3POc55z73dmZ+dzn/Occ26qCkmS5vMdky5AkjS9DAlJUpchIUnqMiQkSV2GhCSp6/RJF7DSzj777NqyZcuky5CkVeWBBx74q6raMLd9zYXEli1bOHLkyKTLkKRVJcnn52v3cJMkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKlrzV1xremy5bqPz9v+uRvfPOZKJC2HIaFl6wXAKPsaHtJ0MSQ0VRYbPIaJNB6GhFal4TAxMKRTx5DQkoxyiEnS6mNIaNVzfkM6dQwJrQsenpKWx5DQvNbCYaW18D1Ik3bSi+mS7E3yTJLPDLW9OsnBJE+0r2e19iS5KclskoeSXDC0z87W/4kkO4fa35Dk4bbPTUmy0GtIo9py3cf/9iFpYYu54vpWYPuctuuAe6tqK3BvWwe4HNjaHruBm2HwBx+4AbgIuBC4YeiP/s3Au4b2236S15AkjclJQ6KqPgUcn9O8A9jXlvcBVw6131YDh4Azk7wWuAw4WFXHq+o54CCwvW17ZVUdqqoCbpvzXPO9hrRiHFVIC1vunMQ5VfV0W/4icE5b3gg8NdTvaGtbqP3oPO0Lvca3SbKbwciFzZs3L/V7UeMfSklzjTxxXVWVpFaimOW+RlXtAfYAzMzMnNJa1hqD4Zs8A0r6dsu9C+yX2qEi2tdnWvsx4Nyhfpta20Ltm+ZpX+g1JEljstyRxAFgJ3Bj+3rnUPu1SfYzmKR+oaqeTnIP8OtDk9WXAtdX1fEkX0lyMXAYeDvw2yd5DY3I0cPJOaqQBk4aEkk+CvwEcHaSowzOUroRuD3JLuDzwNta97uAK4BZ4EXgnQAtDN4H3N/6vbeqXpoMfzeDM6heDtzdHizwGpKkMTlpSFTV1Z1Nl8zTt4BrOs+zF9g7T/sR4Px52r8832tIksbHT6aTJHUZEpKkLu/dtE44Wb18TmJrPTMkpCUwMLTeGBLSMhkYWg+ck5AkdRkSkqQuDzetYU5Wj4+HnrRWOZKQJHU5kpBWmKMKrSWOJCRJXYaEJKnLw01rjJPV08VDT1rtHElIkrocSUhj4qhCq5EhIU2AgaHVwsNNkqQuRxJrgJPVq5ujCk0zRxKSpC5DQpLU5eEmaYr0Dh16GEqT4khCktTlSEJaBRY6OcFRhk4lQ2KV8owmSeNgSEirnKfQ6lQyJKQ1xMDQSnPiWpLU5UhiFXEeQtK4OZKQJHU5kpDWKC/M00oYaSSR5BeTPJLkM0k+muS7kpyX5HCS2SQfS3JG6/uytj7btm8Zep7rW/vjSS4bat/e2maTXDdKrZKkpVt2SCTZCPwLYKaqzgdOA64C3g98oKq+H3gO2NV22QU819o/0PqRZFvb73XAduBDSU5LchrwQeByYBtwdesrSRqTUeckTgdenuR04BXA08CbgDva9n3AlW15R1unbb8kSVr7/qr6alV9FpgFLmyP2ap6sqq+BuxvfSVJY7LsOYmqOpbkPwBfAP4f8CfAA8DzVXWidTsKbGzLG4Gn2r4nkrwAvKa1Hxp66uF9nprTftF8tSTZDewG2Lx583K/JWld8FoKLcWyQyLJWQze2Z8HPA/8AYPDRWNXVXuAPQAzMzM1iRpOFU971alkYOhkRjnc9FPAZ6vq2ar6G+CPgDcCZ7
bDTwCbgGNt+RhwLkDb/irgy8Ptc/bptUuSxmSUkPgCcHGSV7S5hUuAR4H7gLe2PjuBO9vygbZO2/6JqqrWflU7++k8YCvwaeB+YGs7W+oMBpPbB0aoV5K0RKPMSRxOcgfwZ8AJ4EEGh3w+DuxP8qut7Za2yy3AR5LMAscZ/NGnqh5JcjuDgDkBXFNVXwdIci1wD4Mzp/ZW1SPLrVfSwjz0pPlk8GZ+7ZiZmakjR45MuowV45yEJs3AWB+SPFBVM3PbveJa0sgchaxdhoSkBTmaXd+8wZ8kqcuRxBTynZukaWFISFpRi3mT47zF6mFISJoaToBPH0NC0tgZBquHISFponqHpwyS6eDZTZKkLkcSklYtRxunnrflmAKe8iqtPENjabwth6R1xVHGyjAkJK15BsbyGRKS1hUv9lsaQ0KSFrDeRyGGhCTNsZhrNxZjOFR6YbOYEJrk6MeQkKRTZKlhM41nOhoSEzKNvwySNJdXXEuSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLU5cV0Y+QFdJJWG0cSkqQuQ0KS1GVISJK6RgqJJGcmuSPJ/0zyWJIfSfLqJAeTPNG+ntX6JslNSWaTPJTkgqHn2dn6P5Fk51D7G5I83Pa5KUlGqVeStDSjjiR+C/gvVfUPgB8CHgOuA+6tqq3AvW0d4HJga3vsBm4GSPJq4AbgIuBC4IaXgqX1edfQfttHrFeStATLDokkrwJ+DLgFoKq+VlXPAzuAfa3bPuDKtrwDuK0GDgFnJnktcBlwsKqOV9VzwEFge9v2yqo6VFUF3Db0XJKkMRhlJHEe8Czwe0keTPLhJH8HOKeqnm59vgic05Y3Ak8N7X+0tS3UfnSe9m+TZHeSI0mOPPvssyN8S5KkYaOExOnABcDNVfV64P/yzUNLALQRQI3wGotSVXuqaqaqZjZs2HCqX06S1o1RQuIocLSqDrf1OxiExpfaoSLa12fa9mPAuUP7b2ptC7VvmqddkjQmy77iuqq+mOSpJH+/qh4HLgEebY+dwI3t651tlwPAtUn2M5ikfqGqnk5yD/DrQ5PVlwLXV9XxJF9JcjFwGHg78NvLrXdSvMpa0mo26m053gP8fpIzgCeBdzIYndyeZBfweeBtre9dwBXALPBi60sLg/cB97d+762q42353cCtwMuBu9tDkjQmI4VEVf05MDPPpkvm6VvANZ3n2Qvsnaf9CHD+KDVKkpbPK64lSV2GhCSpy5CQJHUZEpKkLkNCktRlSEiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1jXqDP83DO79KWiscSUiSugwJSVKXISFJ6jIkJEldhoQkqcuQkCR1GRKSpC5DQpLUZUhIkroMCUlSlyEhSeoyJCRJXYaEJKnLu8CuEO/8KmktciQhSeoyJCRJXYaEJKnLkJAkdY0cEklOS/Jgkj9u6+clOZxkNsnHkpzR2l/W1mfb9i1Dz3F9a388yWVD7dtb22yS60atVZK0NCsxkvh54LGh9fcDH6iq7weeA3a19l3Ac639A60fSbYBVwGvA7YDH2rBcxrwQeByYBtwdesrSRqTkUIiySbgzcCH23qANwF3tC77gCvb8o62Ttt+Seu/A9hfVV+tqs8Cs8CF7TFbVU9W1deA/a2vJGlMRh1J/Cbwr4BvtPXXAM9X1Ym2fhTY2JY3Ak8BtO0vtP5/2z5nn167JGlMlh0SSf4J8ExVPbCC9Sy3lt1JjiQ58uyzz066HElaM0YZSbwReEuSzzE4FPQm4LeAM5O8dCX3JuBYWz4GnAvQtr8K+PJw+5x9eu3fpqr2VNVMVc1s2LBhhG9JkjRs2SFRVddX1aaq2sJg4vkTVfUzwH3AW1u3ncCdbflAW6dt/0RVVWu/qp39dB6wFfg0cD+wtZ0tdUZ7jQPLrVeStHSn4t5NvwLsT/KrwIPALa39FuAjSWaB4wz+6FNVjyS5HXgUOAFcU1VfB0hyLXAPcBqwt6oeOQX1Sp
I6ViQkquqTwCfb8pMMzkya2+evgZ/u7P9rwK/N034XcNdK1ChJWjrvAjsC7/wqaa3zthySpC5DQpLUZUhIkroMCUl
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"hist(A.data,100);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
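"**Note**: the graph is not fully connected (only particle pairs closer than 1.5 mesh cells are linked), so it captures only small-scale interactions, not the long-range ones. Pair distances are also computed without periodic wrapping across the box boundaries."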
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def update_edge_fn(edges, senders, receivers, globals_):\n",
" del globals_\n",
" # Models particle-particle interactions contribution to grav potential\n",
" distance = jnp.linalg.norm(senders[\"position\"] - receivers[\"position\"], axis=1)\n",
" grav_potential_per_edge = - 1. /(distance + 0.1) # The 0.1 is to soften the forces on small scales\n",
" return frozendict({\"grav_potential\": grav_potential_per_edge})\n",
"\n",
"def update_node_fn(nodes, sent_edges, received_edges, globals_):\n",
" del sent_edges, received_edges, globals_\n",
" \n",
" # Computes momentum\n",
" momentum_norm = jnp.linalg.norm(nodes[\"momentum\"],axis=1)\n",
" kinetic_energy = 0.5 * momentum_norm ** 2 \n",
" return frozendict({\"kinetic_energy\": kinetic_energy})\n",
"\n",
"def update_global_fn(nodes, edges, globals_):\n",
" del globals_\n",
" # At this point we will receive node and edge features aggregated (summed)\n",
" # for all nodes and edges in each graph.\n",
" hamiltonian_per_graph = nodes[\"kinetic_energy\"] + edges[\"grav_potential\"]\n",
" return frozendict({\"hamiltonian\": hamiltonian_per_graph})\n",
"\n",
"# Create the Hamiltonian Graph\n",
"hamiltonian_gnn = jraph.GraphNetwork(\n",
" update_edge_fn=update_edge_fn,\n",
" update_node_fn=update_node_fn,\n",
" update_global_fn=update_global_fn)\n",
"\n",
"# Function that computes the hamiltonian for input position and momentum\n",
"def hamiltonian_from_state_fn(position, momentum):\n",
" # Update variables in graph\n",
" graph = static_graph._replace(nodes={\n",
" 'position': position,\n",
" 'momentum': momentum})\n",
" output_graph = hamiltonian_gnn(graph)\n",
" return output_graph.globals[\"hamiltonian\"].sum()\n",
"\n",
"# Computes the derivatives of the Hamiltonian\n",
"hamiltonian_gradients_fn = jax.grad(hamiltonian_from_state_fn, argnums=[0, 1])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def gnn_nbody_ode(state, a, cosmo):\n",
" \"\"\"\n",
" state is a tuple (position, velocities)\n",
" \"\"\"\n",
" pos, vel = state\n",
" \n",
" # Take the derivatives against position and momentum of the hamiltonian of the system\n",
" dh_dposition, dh_dmomentum = hamiltonian_gradients_fn(pos, vel)\n",
" \n",
" # Hamilton equations\n",
" dpos_da = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * dh_dmomentum\n",
" dvel_da = - 1.5 * cosmo.Omega_m / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * dh_dposition\n",
"\n",
" return dpos_da, dvel_da"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"res = odeint(gnn_nbody_ode, [poss[0], vels[0]], jnp.array(scales), cosmo, rtol=1e-4, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkwAAAJBCAYAAAC0+uodAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd3hV1dLGZ53kpPcEQiihhKogRYoIKqhwFSzoFRUuihU78lmuXq9dr+K1Y0dRQQUVRVBBBSlSRDoYOoQSSgKk976/P4S9Z95wckISjl6d3/P4uIZ1yt5rz15nZ827ZoxlWaQoiqIoiqJ4xvV7H4CiKIqiKMofHX1gUhRFURRF8YI+MCmKoiiKonhBH5gURVEURVG8oA9MiqIoiqIoXtAHJkVRFEVRFC/U64HJGHOBMWabMWanMebBhjoo5a+F+pFSX9SHlIZA/UipCVPXPEzGGD8i2k5Eg4hoPxGtIqIRlmVtbrjDU/7sqB8p9UV9SGkI1I8Ub/jX4729iWinZVm7iIiMMZ8S0aVE5NG5wsNDrdjYKNuuqgq228aUi9daljy0qqpQ6Hez95aJPrSDgnbb7bKyBPicQGFXVETAZ1XZbZermGqiqiqEteWDqAvW8ixL/oOfXwF7b7DoKysrhdcGwGf7sc+V3xsauhU+qxmzKkUfP1ciospKORb8mrhcRXY7KyudCgtzDNWNE/KjIP8gKzQgzLbL/YOc46uU1z24qkLYmey1RESxFSV2O9/PLfoqjbw+UcGxdrusrED0FcD3ukhegxB2vQrhtTho/J1hlXBPGPnqSpc85oAA5x4pL5e+GlReJOxc8KHwKue7DPwNVR4YLmx3ab7dDrSkD+E4BlXJ/gD2elzeTi3Ny7AsqxGdOCc8F0W63Fa8v3PvF7DjdsE9xMeGiCjbT84ZjSqcsS50yXmrAvwo2+1cowDwhSDwV7cl70f+WfhnLvqrH3tvdKWcP8qMn7Dxmh3xd+afhPJC0RcKx1gKn8XvuSIYiwy8/9hxhcDnFsB7o2CsPHGgspyyqip8MhcREUW5/K2m7F4qYXMxHgSOXVYNflQC518G1zfd7fzWRHqZ99CfOZUwp5RX8yPnvXgNquC9eTAfZbP7qxGba4mq+zr6rz/z32IYC/we7jt4r6J/huBczc6hil2xfVXllFlVWc2P6vPA1IyI9jF7PxH1qekNsbFR9Mgjd9h2UdGpdjsg4KB4bXl5Y2EXFJwubP7gExiYKvoCA/cJu337a+32vn1ylbWkpI2ws7MHC9vPz/mhCQnZRDVRWNiVfa68cMHBcuyrquTkER6+wm7zcSEi2r9/l7AjIxOFHcAeICyYZHv1OlPYe/c+bbddrnzR5+8v7ZycQcIuL3d+y0JDN9jtl1++gerBCflRaEAYXdB+qG0fadzZbhdn7xav7VKSLezJse2FfUXmdru9MEw+SOf6y8ns792us9t79y4WfYtzpf/h5H96VCu7/UvOXtHnhp++cnbT9i9Il30uefMXRjQXdrNmve324cMbRV/btLXCng0+NDDfuf/c8JBzJEn6QaOUeXY7qUz6zMLwpsLuWJIj7ET2sBUCD1u3p8yVg1N7TnguivcPpNfiu9j2srAmznHB9RuQL+emL6LlnHHrEef3dGWonLfwIeGr5mfY7abZ8r7GsUqAh1z+WfjDhg/A/Ef0CvieVDZfEBEtAt9/q7Ez/4w7uEb09S46IuxdAfJhujO759aFxIm+9+I6Cns0u/+6F2WIvmWh8cK+NLd2rnEZnOsJcsJ+1NQvgD6K62Db24Ki7DY+8PYuPCzsaTFthc39aCv7HCKifXDNxjfpZreHwvzTpThL2OjPHPSbNH/5x3okewC5NGeP6MMH4nkwH01n98ltR+QzZwfwdTyOOPaAtTE4RvTNhe/pwXwH79WUQPlHfy8Yq0D2vfx8BuXKZ4hjnHTRtzFmjDFmtTFmdX5+ofc3KArAfagE/lJRlNrC/SgX/h
JVlNrC/Si7hocR5c9HfVaYDhBRC2Y3P/pvAsuyJhLRRCKiNm1irZAQ50mzTRtntcPPTz5Bl5bmCnvv3qeEnZV1sd02RjptVNQ8eO1FdruyUv5FdOiQXIGJjNwp7OJi56+AsjL51zOuOFVVOU+rkZHySZZ/DhFRkybvCZt/dlZWnugLhidsFzzZ81UlPP709DHCjohYard5GBD7iIg6dXpA2Js3v2S3w8N/sdtudybVA69+xH0oIrSxlcaWo1uyVZUHJsi/Lu+/f4Cwz4S/anYm9LDbufBXWVe24kBElJe3327j+CdCiA7/gg8Pd8KgCRkyRDp0wOPCvuCC1nZ7xMO3iL7TYPWiUxu58lPCzi+THS8RUVi7ocIuPyJ9l68MXdn3HtHXu/1UYY/7pJ3dHtluiOi74fz/Cjsubqawp09fbbc/3TJD9FHKXKojJzwXJQZFWnw1YCnzhQsvHCbe9/bH8pzwL3a+4oR9yUHRwt5Y4YShoiBcgH9lF0H/UuaTPYrlPbeNrVwREfXr59z3L31+s+jDlay1if2EfQZbvUqokKHdOREthI2fxVfBZp1yhejrHH+asJeufttuvwh953cZKexXYKXOb9HjdpuvjhbDCsIJcsJ+1DYw3OKrP1PZqlHbsx4S71v6o4xu4NjNYau+sfCH4VomCSAi6sNWq8oh+MfnR6LqPpnCVgVxlRdXnnv2ut1uF/0g54XIChnq5ateRHLFpgmslq4MkZH3FuWeF1P+1bSXsE9tJu2iDVPs9rgW8vc8qPW5wp4MK6KX/fy83earXriCe4z6rDCtIqJ2xpjWxpgAIrqaiL6ux+cpf03Uj5T6oj6kNATqR0qN1HmFybKsCmPMnUT0AxH5EdH7lmXVLPJRFED9SKkv6kNKQ6B+pHijPiE5sixrDhHNqf07XFRZ6SxfFhY6IYaSEinIDQhIEzaGj2Jjv7LbERHLRJ+/vxT78tDZwYN3i77gYLnIlpUlw1/G5LC2FEei2Dw42AmHhYb+KvrKy+XS/OHD1wo7MnKR3XbDkmol7qzCHVxRTpinUSMZPuFCbSI5NlFR34u+li17Czs3Vwox+S650lJn6ZaHIuvCCfmRXwC52LJx796O0Ltjx7fES995Z5Kwx46Vy83hLAzVnYnHiYiuvlqGBOLjHb99553rRN+uufcJG5fa09Ic4Ww8hPo2bPhI2Nu2vWq3i1zbRV8FLNNv2fKl/N6Ow+w27pwpAQH8qSAq7cJCIFdc8TfRt3nzJcLuy4ShMSBcveoqeUxIVpYTGrSsJ0XfxMVP48trzYnORYf8g0UY6ML2Tti+det7xWvbPSHHbuLE0cLex3aZBcG4ciE+EdE98c5rFyz4RPSF7P9F2Bjy4BsEMEQzLEXKEFLZZyWALySHyPd2ge/d0LQneQLfmxIUKexiFhq8+GL5vQsXStHzoRgntNspVM5TV1whN/mcf74UUI8Y4fjr3ft+ttvpmTs8HnttOFE/OuIfRG81OsW2Q1hIKClJ/gZEnS5D6tOmnSXsn5igvh+E9Vu2PFvYPMy/deGjoi+hQIZRcfMAD7/jpoRHYHNI2ndjyRMoTL/tsHy2fLGJswkKBeLV/KhK+tFy5kejz7tK9C1b9oP8HnYfd4CNBlcP+0nYo0a9LezBg7fY7UZso0x6XrVILBFppm9FURRFURSv6AOToiiKoiiKF/SBSVEURVEUxQt1Lo1SF1q3TrAee+w62+a6pKS2MgbvMjLmuW9fS2EXFjpxS5dLajsyMy8TNk8SydMaEFXPMM61OUSYUVy+toAlqiQisliG7rIyqUOKjV0gbEyYWVjYym6XQ8wZk1GipikuzklD0KzZi6IPz+fAAUef0bq11N5gFvS8vP4e+4OCnC38L754F6Wmbq9rdt0TIiisidW86zW2fVq8cw0aNeokXvvOO1IHsXSp3BL9+usT7HYexKxbtpT6gt69HW1OVNSPom/2bDnmu3bJ/oED72F9q0XfPqa/ICKKYEnZci
BRXFPQlmRlpQg7niVizIXEazvTZALCgZDgjW8n7giJGQl0DVyjtZ5txyci2uElIR3f4lxtG/1PT66xLMuzgKYBaRE
"text/plain": [
"<Figure size 720x720 with 16 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"figure(figsize=[10,10])\n",
"for i in range(16):\n",
" subplot(4,4,i+1)\n",
" imshow(cic_paint(jnp.zeros(mesh_shape), res[0][::2][i]).sum(axis=0), cmap='gist_stern',\n",
" vmax=cic_paint(jnp.zeros(mesh_shape), poss[::2][i]).sum(axis=0).max(),vmin=0)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/local/home/flanusse/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py:488: ComplexWarning: Casting complex values to real discards the imaginary part\n",
" return _convert_element_type(operand, new_dtype, weak_type=False)\n"
]
}
],
"source": [
"k, pk_c = power_spectrum(\n",
" compensate_cic(cic_paint(jnp.zeros(mesh_shape), res[0][-1])),\n",
" boxsize=np.array([25.] * 3),\n",
" kmin=np.pi / 25.,\n",
" dk=2 * np.pi / 25.)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, '$P(k)$')"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEQCAYAAABbfbiFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAA/8UlEQVR4nO3dd1xV9R/H8df3spGhKE5QnOFgqLhH7lGucmvlHlnZMttWlpktrcxVaWb9NDNnamruLQ6cuUXBgYqCyJBxv78/LpCKKOiFy4XP8/HgIdx77jmfq8ib7/ec8/kqrTVCCCHE7QyWLkAIIUTeI+EghBAiAwkHIYQQGUg4CCGEyEDCQQghRAYSDkIIITKwtXQB5lKsWDHt4+Nj6TKEEMKq7Nmz56rW2vPux/NNOPj4+LB7925LlyGEEFZFKXX2Xo/LtJIQQogMJByEEEJkIOEghBAig3xzzkEIa5SUlER4eDgJCQmWLkXkc46Ojnh5eWFnZ5el7SUchLCg8PBwXF1d8fHxQSll6XJEPqW1JjIykvDwcMqXL5+l11j9tJJSqqNSakZ0dLSlSxEi2xISEihatKgEg8hRSimKFi2arRGq1YeD1nqZ1nqou7v7Q70+9Gost5JTzFyVEFknwSByQ3a/z6w+HB6F0ajZ9sPLTBv/OvO2HSMx2WjpkoTIdUopXn/99fSvv/zySz788MN7buvi4vLQxwkNDaVGjRoP/XqRuwp0OCiMtCt6mZdTZtFsVVu++ewt/th5iqQUCQlRcDg4OLBw4UKuXr1q6VJEHlKww8Fgg8fQpeh+y3DyrMAbydOpv7wtX04Yw5/BoSRLSIgCwNbWlqFDhzJx4sQsbf/qq69SvXp1WrZsyZUrVwAICQmhfv36+Pv789RTT3H9+nUA9uzZQ0BAAAEBAXz//ffp+2jatCkhISHpXzdu3Jj9+/eb702JRyZXKwGqfFPcX1iLPrkW9xUf8Pb17zi1bAGf/tMH/7YD6BjohY1B5oVFzvpo2WGOXLhh1n1WK+3GBx2rP3C7F154AX9/f0aPHn3f7WJjYwkKCmLixImMHTuWjz76iMmTJ/Pcc8/x3Xff8fjjjzNmzBg++ugjJk2axIABA5g8eTJNmzbljTfeSN/PoEGD+Pnnn5k0aRLHjx8nISGBgICAR36/wnwK9MjhDkqhKrfCbeQWdM/fKF7ElTG3vsJ3cXs++nwCy0LOYzTKetsif3Jzc+O5557j22+/ve92BoOBnj17AvDMM8+wZcsWoqOjiYqK4vHHHwegX79+bNq0iaioKKKiomjatCkAzz77bPp+unfvzl9//UVSUhIzZ86kf//+OfPGxEOTkcPdlEJV7YDrY09gPLwIr1VjGXtzPPsX/s77q5+jcbuetK1RCoOMJISZZeU3/Jz0yiuvUKtWLQYMGABASkoKtWvXBqBTp06MHTs2w2se9korZ2dnWrduzZIlS5g/fz579ux5+MJFjpCRQ2YMBgx+XXF5dQ/GTt9T2eUW4+I+pNiCzrz19RRWH76E1jKSEPmHh4cHPXr04KeffgLAxsaGkJAQQkJC0oPBaDSyYMECAP73v//RuHFj3N3dKVKkCJs3bwZgzpw5PP744xQuXJjChQuzZcsWAH777bc7jjd48GBGjhxJnTp1KFKkSG69TZFFEg4PYmOLodYzOL8WgvGJr6nhHMXnN9/Bed7TvDHxJ9YdjZCQEPnG66+/ft+rlgoVKsSuXbuoUaMG69atY8yYMQDMnj2bN954A39/f0JCQtIfnzVrFi+88AKBgYEZ/p/Url0bNze39JGKyFtUfvnBFhQUpHNlPYekeFKCZ5K04UscE6+xNqUmy4oO5Kkn2tO0cjG5oUlky7///kvVqlUtXYZFXLhwgWbNmnH06FEMBvk9NTfc6/tNKbVHax1097byL5Jddk7YNHwBx9cPktJiDI0dTzEp6iVuzunDK5PnsfXkVRlJCPEAv/zyC/Xq1WPcuHESDH
mUjBweVUI0KVsnk7Lte2xT4lic0oj1JQfSt30z6lcomvv1CKtSkEcOIvfJyCE3Obpj0/Jd7F87iLHBS3Sy283EK4M5M3MQL01dyu7Qa5auUAghsk3CwVwKFcW27cfYvnoA6gymh90WvooYwKEfh/HijBXsPXfd0hUKIUSWSTiYm2sJbJ/8AptXQjAE9uU5u3V8caEfu2a8yEs/ruFAeJSlKxRCiAeScMgp7l7YdvkWw0vB2NbowjDb5XwW/gzrp73CSzM3cPiCrD8hhMi7JBxymkcF7Lr9gBqxA3vftrxsu4hx5/qw/PtRvDx7C0cvmbeXjhAP42FacYeGhuLk5ERgYCDVqlVj+PDhGI1GQkNDUUrx3nvvpW979epV7OzsePHFFzPs5+eff8bT0zN9Pz/88MMjvZfM2NjYEBgYSI0aNejevTtxcXGA6S7vZ555Jn275ORkPD096dChQ47UYS0kHHJLcV/sev0CwzbjVKkxo+3mM+Z0b/747i1e+XU7JyJiLF2hENlWsWJFQkJCOHDgAEeOHGHx4sUAlC9fnuXLl6dv98cff1C9eubtQXr27ElISAgbNmzgnXfeISIiIkvHT05OznKtTk5OhISEcOjQIezt7Zk2bRpgurHv0KFDxMfHA7BmzRrKlCmT5f3mV1YfDla3TGgpf+yemQ+D1+JWvhbv2/3G2yd6M/vb93lt7i5OXblp6QpFAXXz5k1atmxJrVq18PPzY8mSJQAEBwfj7+9PQkICsbGxVK9enUOHDt3xWltbWxo2bMjJkycBU++kqlWrknZ5+e+//06PHj0eWEPx4sWpWLEiZ8+epX///umtOuC/0c2GDRto0qQJnTp1olq1aqSkpPDGG29Qp04d/P39mT59+gOP06RJk/RaAZ544on0MJs7dy69e/d+4D7yO6tvvKe1XgYsCwoKGmLpWrLFKwi7/ksgdAsea8byyflZnD/6F98cegqjf09ealWVckULWbpKkZtWvgWXDpp3nyX9oP1nWdrU0dGRRYsW4ebmxtWrV6lfvz6dOnWiTp06dOrUiffee4/4+HieeeYZatSoQWhoaPpr4+LiWLt27R3N+Xr16sW8efMoUaIENjY2lC5dmgsXLty3htOnT3P69GkqVap03+327t3LoUOHKF++PDNmzMDd3Z3g4GBu3bpFo0aNaNOmDeXLl7/na5OTk1m5ciXt2rW7o9axY8fSoUMHDhw4wMCBA9N7RRVUVh8OVs+nMXaDV8GptRRf8zGfR8zgzJGlTDzYDYeAbrzY8jG8PZwtXaUoALTWvPPOO2zatAmDwcD58+eJiIigZMmSjBkzhjp16uDo6HhHW+9Tp04RGBiIUorOnTvTvn379NBo164d77//PiVKlEhv852Z33//nS1btuDg4MD06dPx8PC47/Z169ZN/+G/evVqDhw4kD7KiI6O5sSJExnCIT4+nsDAQMA0chg0aFD6c/7+/oSGhjJ37lyeeOKJLP195XcSDnmBUlCpFXYVW8KxFXj/8zGTrk7m2KElfBrSncK1nuLFlpUpU9jJ0pWKnJTF3/Bzym+//caVK1fYs2cPdnZ2+Pj4kJCQAEBkZCQ3b94kKSmJhIQEChUyjWrTzjnci729PbVr1+arr77iyJEjLF26NNNj9+zZk8mTJ9/xmK2tLUajaTVGo9FIYmJi+nNpxwdTqH333Xe0bdv2vu8v7ZxDZjp16sSoUaPYsGEDkZGR991XQWD15xzyFaXA90lsR2yDbjOp6GHPVLuv6bO/H+9/OZH3Fx3kUnSCpasU+VR0dDTFixfHzs6O9evXc/bs2fTnhg0bxscff0zfvn158803s7zP119/nQkTJjxwJHAvPj4+6es8LF26lKSkpHtu17ZtW6ZOnZr+/PHjx4mNjc328QYOHMgHH3yAn59ftl+bH8nIIS8yGKBGV2yrdoYDv1N1/Xhm3pjA7n2LGbW3B5XqtGdEs4oUd3O0dKUiH0hOTsbBwYG+ffvSsWNH/P
z8CAoKwtfXFzA1ybOzs6NPnz6kpKTQsGFD1q1bR4UKFR647+rVq9/3KqX7GTJkCJ07dyYgIIB27drdMVq43eDBgwk
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"loglog(k, pk_ref, label='N-body')\n",
"loglog(k, pk_i, label='JaxPM Pure PM')\n",
"loglog(k, pk_c, '--', label='JaxPM Hamiltonian GNN')\n",
"legend()\n",
"plt.xlabel(r\"$k$ [$h \\ \\mathrm{Mpc}^{-1}$]\")\n",
"plt.ylabel(r\"$P(k)$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
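"## Combining large-scale PM forces with particle-particle (PP) interactions\n",
"\n",
"The Hamiltonian defined below adds a short-range particle-particle potential, computed on the graph edges, to a particle-mesh (PM) potential computed by FFT on the mesh. A learned Fourier-space filter (`NeuralSplineFourierFilter`) is applied to the PM potential so that the mesh and particle-particle contributions can be merged; its parameters are then fitted so that the integrated trajectories match the reference N-body snapshots."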
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from jaxpm.nn import NeuralSplineFourierFilter"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the neural network\n",
"model = hk.without_apply_rng(hk.transform(lambda x,a : NeuralSplineFourierFilter(n_knots=16, latent_size=32)(x,a)))\n",
"params = model.init(next(rng_seq), jnp.zeros([64]), jnp.ones([1]))"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"def hamiltonian_from_state_fn(position, momentum, params):\n",
" def update_edge_fn(edges, senders, receivers, globals_):\n",
" del globals_\n",
" # Models particle-particle interactions contribution to grav potential\n",
" distance = jnp.linalg.norm(senders[\"position\"] - receivers[\"position\"], axis=1)\n",
" grav_potential_per_edge = - 1. / (distance + 0.5)\n",
" return frozendict({\"grav_potential\": grav_potential_per_edge})\n",
"\n",
" def update_node_fn(nodes, sent_edges, received_edges, globals_):\n",
" del sent_edges, received_edges, globals_\n",
"\n",
" # Compute gravitational potential by FFT\n",
" kvec = fftk(mesh_shape)\n",
" delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), nodes[\"position\"]))\n",
" pot_k = - delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)\n",
" \n",
" # Apply a correction filter to the potential, will take care of merging PM forces with PP\n",
" kk = jnp.sqrt(sum((ki/pi)**2 for ki in kvec))/2\n",
" pot_k = pot_k * model.apply(params, kk, jnp.ones([1]))\n",
" \n",
" grav_potential = 0.5 * (1 + cic_read(jnp.fft.irfftn(pot_k), nodes[\"position\"]))\n",
" \n",
" # Computes momentum\n",
" momentum_norm = jnp.linalg.norm(nodes[\"momentum\"],axis=1)\n",
" kinetic_energy = 0.5 * momentum_norm ** 2 \n",
" return frozendict({\"kinetic_energy\": kinetic_energy, \"grav_potential\":grav_potential})\n",
"\n",
" def update_global_fn(nodes, edges, globals_):\n",
" del globals_\n",
" # At this point we will receive node and edge features aggregated (summed)\n",
" # for all nodes and edges in each graph.\n",
" hamiltonian_per_graph = nodes['grav_potential'] + nodes[\"kinetic_energy\"] + edges[\"grav_potential\"]\n",
" return frozendict({\"hamiltonian\": hamiltonian_per_graph})\n",
"\n",
" hamiltonian_gnn = jraph.GraphNetwork(\n",
" update_edge_fn=update_edge_fn,\n",
" update_node_fn=update_node_fn,\n",
" update_global_fn=update_global_fn)\n",
"\n",
"\n",
" graph = static_graph._replace(nodes={\n",
" 'position': position,\n",
" 'momentum': momentum})\n",
" output_graph = hamiltonian_gnn(graph)\n",
" return output_graph.globals[\"hamiltonian\"].sum()\n",
"\n",
"hamiltonian_gradients_fn = jax.grad(hamiltonian_from_state_fn, argnums=[0, 1])"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def gnn_nbody_ode(state, a, cosmo, params):\n",
" \"\"\"\n",
" state is a tuple (position, velocities)\n",
" \"\"\"\n",
" pos, vel = state\n",
" \n",
" # Take the derivatives against position and momentum of the hamiltonian of the system\n",
" dh_dposition, dh_dmomentum = hamiltonian_gradients_fn(pos, vel, params)\n",
" \n",
" # Hamilton equations\n",
" dpos_da = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * dh_dmomentum\n",
" dvel_da = - 1.5 * cosmo.Omega_m / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * dh_dposition\n",
"\n",
" return dpos_da, dvel_da"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"# Precomputing a few data stuff\n",
"ref_pos = jnp.stack(poss, axis=0)\n",
"ref_vel = jnp.stack(vels, axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def loss_fn(params):\n",
" res = odeint(gnn_nbody_ode, [poss[0], vels[0]], jnp.array(scales), cosmo, params, rtol=1e-4, atol=1e-4) \n",
" distance = jnp.sum((res[0] - ref_pos)**2, axis=-1)\n",
" w = jnp.where(jnp.sqrt(distance) < 16, distance, 0.)\n",
"\n",
" # Optional lines to include velocity in the loss\n",
"# vel = jnp.sum((res[1] - ref_vel)**2, axis=-1)\n",
"# wv = jnp.where(jnp.sqrt(distance) < 8, vel, 0.)\n",
" \n",
" return jnp.mean(w) #+ jnp.mean(wv) \n",
"\n",
"@jax.jit\n",
"def update(params, opt_state):\n",
" \"\"\"Single SGD update step.\"\"\"\n",
" loss, grads = jax.value_and_grad(loss_fn)(params)\n",
" updates, new_opt_state = optimizer.update(grads, opt_state)\n",
" new_params = optax.apply_updates(params, updates)\n",
" return loss, new_params, new_opt_state"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"import optax\n",
"learning_rate=0.002\n",
"optimizer = optax.adam(learning_rate)"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"losses = []\n",
"opt_state = optimizer.init(params)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/200 [00:00<?, ?it/s]\n"
]
},
{
"ename": "RuntimeError",
"evalue": "INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-99-768b1d51abff>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m200\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mlosses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: INTERNAL: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory"
]
}
],
"source": [
"for step in tqdm(range(200)):\n",
" l, params, opt_state = update(params, opt_state)\n",
" losses.append(l)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fb8d44bdd30>]"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAkuklEQVR4nO3deXxU9b3/8dcnCUsIISGQBEjCvoRNtgiCK6KIetW6Vq9atbbU/rTXtrettra9va3ttXq1t61WtHWriruo1wWButYKEiASCATClgXIQkJIAiQk+f7+mIGbxgSCTHJmeT8fDx4zmfM9mc8cTt5z5jvnfL/mnENERMJXlNcFiIhI51LQi4iEOQW9iEiYU9CLiIQ5Bb2ISJiL8bqAtvTv398NHTrU6zJERELGqlWrKpxzyW0tC8qgHzp0KNnZ2V6XISISMsxsR3vL1HUjIhLmFPQiImFOQS8iEuYU9CIiYU5BLyIS5hT0IiJhrkNBb2bbzSzXzHLM7AvnPZpZXzNbZGZrzewzM5vQYtk8M8s3swIzuzOQxYuIyLEdz3n0s51zFe0s+wmQ45y71MwygYeAOWYW7b9/LlAMrDSzN5xzeSdUtYhIGHDOUbqvni3ltWwtr6WuoYlbzhwR8OcJ1AVT44B7AJxzG81sqJmlAsOBAufcVgAzex64BFDQi0jEqG9sYsee/Wwpq2VLeS1byut8t2W+cD8sJb4H3zpjOGYW0OfvaNA7YImZOeAR59yjrZZ/DlwGfGxm04EhQDqQBhS1aFcMzGjrCcxsPjAfYPDgwR1+ASIiwaKyruFIgLcM9KLK/TS3mONpUEJPRqT05sqsDEYkxzEiuTfDk3uT2qdHwEMeOh70pznnSswsBVhqZhudcx+1WH4P8HszywFygTVAUxu/p13+N49HAbKysjTtlYgEvT219by3sYxlG0r5bFslVfsPHVnWIyaKYf3jmJCWwCWT044E+rD+ccT16NrRZzr0bM65Ev9tmZktAqYDH7VYvg+4CcB8b0fbgK1ALJDR4lelAyUBqVxEpIs559hSXseyDaUsyytlVWEVzsHAhJ7MHTeA0QPijwT6oMRYoqMCf3T+ZRwz6M0sDohyztX4788FftmqTSKw3znXAHwD+Mg5t8/MVgKjzGwYvoC/GvjXAL8GEZFO09jUzOrCvSzN282yDWVsq6gDYEJaH26fM4pzxqYyflCfTulyCZSOHNGnAov8LyIGWOicW2xmtwA45xYAY4Gn/H3464Gb/csazew24F0gGnjcObc+8C9DRCRwausb+XhTOUs3lPL+xjKq9h+iW7Qxc0R/vn7aMOZkpjAoMdbrMjvMnAu+7vCsrCynYYpFpCvtqj7Asg1lLMsr5dMte2hoaiaxVzfOHpPCOeNSOX1Uf+J7dvO6zHaZ2SrnXFZby4JyPHoRkc7mnCNv1z6W5fm+TM0tqQZgSL9efG3mEM4dl8q0IX2JiQ79AQQU9CISMeobm1i+tZJleaX8bUMpO6sPYgZTMhK5Y14m545LYURy76Dub/8yFPQiEtaq6hp4P9931P5hfjl1DU3Edovm9FH9+e65ozk7M4X+vXt4XWanUtCLSNjZVlHHsrxSlm4oJXt7Jc3Od9XpxZPTOHdcCrNG9Kdnt2ivy+wyCnoRCXlNzY7VhVVHwn1rue8UyLED+3Db7JHMGZvKxLQEooLkvPaupqAXkZDU1Oz4pKCC13N28n5+GZV1DXSLNk4Z3o8bZg5lztgU0vv28rrMoKCgF5GQUlBWyyuri1m0uoTd+w4S3zOGOZm+UyDPGJ1MnyA+BdIrCnoRCXrVBw7x5tqdvLyqmDWFe4mOMs4cnczPLxrHnLEp9IiJnP72L0NBLyJBqanZ8fHmcl5eVcySvFIaGpsZndqbn1yQyVcmp5HSp6fXJYYMBb2IBJXNpTW87O+aKaupJ7FXN645OYMrpmUwIS24x5QJVgp6EfHc3v0N/O/nvq6Zz4uriY4yZo
9J5opp6czOVNfMiVLQi4gnKmrrySncy6I1JSzNK6WhqZnMAfH89MKxXDI5jeT48L6IqSsp6EWkU1UfOMTm0hryS2vYXFpL/u4aNpXWsKeuAYCkuO5ce8pgrpiWzvhBCR5XG54U9CISEPsbGiko+78g31Ray6bSGnZVHzzSJq57NKNS45kzNoXRqfGMHdiHk4cm0T0m9AcOC2YKehE5bg2NzfxtQynrdlaTv9sX6EVV+zk86nn3mChGpfTmlOH9GJ0az5gBvRmVEk9aYmzEXp3qJQW9iHRYQ2MzL68q5qH3CyjZe4CYKGN4chwT0xO4Ylo6o1PjGZ3amyH94oJmGj1R0ItIB9Q3NvFSdjEPf7CFkr0HmJyRyN1fmcCpI/ur2yUEKOhFpF31jU28uLKIP32whV3VB5k6OJHfXDaRM0b11/nsIURBLyJfcPBQEy+sLOLhD7awe99Bsob05d4rTuK0kQr4UKSgF5EjDh5q4rnPClnw4RZK99UzfWgS9181iVkj+ingQ5iCXkQ4eKiJZ1cU8siHWyirqWf6sCR+99XJzByugA8HCnqRCHagoYlnV+zgkY+2Ul5TzynDk/j91VOYOaKf16VJACnoRSLQ/oZGnl1eyCMfbaWitp5ZI/rx4DVTmDFcAR+OFPQiEebt3F38xxvrKa+p59SR/fjTnKlMH5bkdVnSiRT0IhGibN9Bfv76ehav382EtD48fO1UsoYq4COBgl4kzDnneGlVMXe/mcfBxmbumJfJN08fRky0LnSKFAp6kTBWVLmfnyzK5ePNFUwfmsQ9l09keHJvr8uSLqagFwlDzc2Ov366nXvfzceAX10ynmtnDNGAYhFKQS8SZgrKarnjlbWs2lHFmaOT+fWlE0jv28vrssRDCnqRMHGoqZlHP9rK75dtplePaB64ahKXTknTBU+ioBcJB+tKqvnRy2vJ27WPCyYO4D8vnqCp+OQIBb1ICDt4qIk//G0zj3y0laS47iy4birzJgz0uiwJMgp6kRCVvb2SH72ylq3ldVw5LZ2fXjiOhF7dvC5LgpCCXiTE1NU3cu/ijfx1+Q4GJcTy9M3TOX1UstdlSRBT0IuEkFU7qrj9+TWU7D3ADTOH8sPzxhDXQ3/GcnTaQ0RCQHOz4+EPt/DA0k0MTOjJS9+aqeELpMMU9CJBrnTfQb7/Yg6fFOzhX04ayG8um0ifnuqLl45T0IsEsfc3lvHvL33OgYYm7r38JK7MStd58XLcOhT0ZrYdqAGagEbnXFar5QnAM8Bg/+/8b+fcE/5lTUCuv2mhc+7iwJQuEr7qG5u4d3E+j/19G5kD4nnwX6cwMiXe67IkRB3PEf1s51xFO8tuBfKccxeZWTKQb2bPOucagAPOucknWqhIpNhWUcd3nlvNupJ93DBzCD++YCw9u0V7XZaEsEB13Tgg3nyfKXsDlUBjgH63SMR4dXUxP3ttHd1ionj0+mnMHT/A65IkDHQ06B2wxMwc8Ihz7tFWyx8E3gB2AvHAV51zzf5lPc0sG1/w3+Oce62tJzCz+cB8gMGDBx/XixAJdbX1jfzstXUsWlPC9GFJ/P7qyQxMiPW6LAkTHQ3605xzJWaWAiw1s43OuY9aLD8PyAHOBkb423zsnNsHDPGvOxx4z8xynXNbWj+B/83jUYCsrCx3Aq9JJKTkFlfznedWU1i5n++dM5rbzh5JtIYTlgDq0BQzzrkS/20ZsAiY3qrJTcCrzqcA2AZktlp3K/ABMCUglYuEuOZmx58/2splD39CfWMzz8+fye3njFLIS8AdM+jNLM7M4g/fB+YC61o1KwTm+NukAmOArWbW18x6+B/vD5wK5AWufJHQVFFbz01PruTXb29g9pgU3rn9dE3QLZ2mI103qcAi/7m7McBC59xiM7sFwDm3APgV8KSZ5QIG3OGcqzCzWcAjZtaM703lHuecgl4i2t83V/C9F3OoPnCIX10ynutOGaJz46VTHTPo/V0uk9p4fEGL+zvxHem3bv
MPYOIJ1igSFpxz/G7pJv74fgEjknvz9M3TyRzQx+uyJALoyliRLtDU7LhrUS7Pryziymnp/PKSCcR217nx0jUU9CK
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot(losses)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"res = odeint(gnn_nbody_ode, [poss[0], vels[0]], jnp.array(scales), cosmo, params, rtol=1e-4, atol=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkwAAAJBCAYAAAC0+uodAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOydd3hV1dLGZ53kpBJSKaGEEgJBQIpKRFBBBBFUsKDCRbFd7Oi1XNHPem1Yr4qKoiigYsGr2FBBBCli6EgvoYRAAiQhnfT9/SHsPfOGc04g4djm9zw+rmGdsvfas9fZWfOuGWNZFimKoiiKoiiecf3eB6AoiqIoivJHRx+YFEVRFEVRfKAPTIqiKIqiKD7QByZFURRFURQf6AOToiiKoiiKD/SBSVEURVEUxQd1emAyxgwyxmw2xmwzxoyrr4NS/l6oHyl1RX1IqQ/UjxRvmOPNw2SMCSCiLUQ0gIgyiGgZEY2wLGtD/R2e8ldH/UipK+pDSn2gfqT4IrAO7+1JRNssy9pORGSM+YiIhhKRR+eKiAi3YmOjbLu6OtRuG1MhXmtZ8tCqq8Oh383eWy760A4J2WG3y8vj4XOChV1Z2RA+q9puu1yHyBvV1WGsLR9EXbCWZ1nyHwICith7Q0VfeXkZvDYIPjuAfa783vDwTfBZzZlVJfr4uRIRVVXJseDXxOUqsdu5uVlUXJxn6Pg4Jj8KCgyxwoIjbLs6wLl+ldXSh0Kq5fnlw7hFVTl+UhjgFn3VJE8nNizWbpeXF4m+ykp5fUoJrgE7xsIq6Zsu8vwHCx4/DrDlkveI2+34XxV8Twgc88FA6fcNqpyxCwAfqmDjTUQUwD4rpLpS9OEYh0N/oMXuJzj3jLKCbMuyGtGxc8xzUaTLbTVhY3CQXaMAOK6oKnl9cwPk2DWpdOaFIpf0oyojr1peUAO77YJrFAzXO9iSdqVx5gyD18jAfMLOIRr8sxJeWwS+XxjozD8NK0pEXwO4x8pNgLD5MaNnZ7BzJyKKZ5/ttuTccwg+NwrGio+ri43F7uoKyq2u8stcREQU7Qq0mjGf52PpgmsUAWN3EO6VuMpSu30I7m28ZnvYvY5jEwRjGQR+Vc3GDq8RXk/DXhEDfoS+jb6fHRhit5tWyt9OnDeq0H/ZOZTBMRXAuDVk5x8C90w5fG5olbwGLnN0V0mvrqTso/hRXR6YmhPRbmZnEFGKtzfExkbRQw/datslJZ3sdlDQXvHaiorGwi4qOkXY/MEnODhd9AUH7xZ2+/ZX2+3du+Uqa2lpW2EfPDhQ2AEBzk0dFraevFFc3JV9rrwwoaFy7KurQ4QdEZFqt/m4EBFlZGwXdmRkgrCD2ERkwc1y2mlnCHvXrifststVKPoCA6WdlzdA2BUVzm9ZePgau/3f/15HdeCY/CgsOIL6JA+z7fLY9nY7q1D6UHJpnrC/gXEblLfTbs+LaCb6SmDCuqIr96GfRV8e+xwionUwsZwRk2i3f8iXvhoGNzAHj98N05vVoKmwmzbtZrfz4XuS0hcJe0a09Ps+RVl2OxKO/0Ci9IPwnfM9HiOOcUrxfmHzHwU893u2fbuLjo9jnouaBAbThCZdbPvTqDZ2OxJ+2IbC9f0wpp2w79r3q91eHN5E9OXDg+nMlr3tdli+PN3EsgKvNv8BwgeMTPYDSiSv4WV5O0Qf/xwiosXgRwuZH52TuVL09WZ+QkS02y3/kG1b7swh+BB3b4vThf0Q++ym8GC2LjRG2HgN+IN5GPvxPT/veF2IiI7Dj5oFBNGHsUm2zccS/btvUaawuc8REV2fs9lur4Xzz4Frdn+z0+z2MBibhIpiacMfS3xuw2uUDg+1bvawNeJgmujDP47Qj96OS7bb47JWi74OMG/gZ0Wyh6DtwfIP99kNWwh7YEGGx8/F8+nGXktEFAzfe4Qz4bfkCCdc9G2MGWOMWW6MWV5YWOz7DYoCcB8qZz+4inIscD/Kr/b8oKoo3uB+dB
BWSpS/NnVZYdpDRC2Z3eLwvwksy5pERJOIiNq2jbXCwpzVzbZtndWOgAD5JFhWli/sXbseF3Zu7oV22xjptFFRc+C1F9jtqioZXti3T67AREZuE/ahQ85fk+XlchUCV5yqq52n1chI+Rc+/xwioqZN3xY2/+zcXPmXZSj8teGC1Q++qoTHn5U1RtgNGzorDTwMiH1ERB073ifsDRtetNsREb/Ybbc7h+qATz/iPhQR3tjKZ0u/HRL62O3nxz4nPviuu04TdgosTe9vdJLdzoe/cM8Ok5GhoqJ9Hk8gq1T6qhtWFfiKYDJbjSAi6pMyVtjXXnu73R5xc7Loax3VStqt+wq7rMz56349WwUiIgrveImwS/avEzZfGbrjjHtF3xlJzwr72UlOeLIaPvem834Rdrt2Xwv77bffsdupm78SfbTtWzpOjnkuah3c0NrN/vrcFd/dbg8bJsP2k9/8XtiREAL5lK3WhcEPaGq4XCnfxP76HwgP//i5afCXdTpbzUksl6vB25v3FHb//ufa7WlTnxB9CWXyvUtg5SeUrVbhaz+MThR2P1g12RwSZbe/6y5XnnvAquYLbAUGV+LOYCu6RERzDspV9tgNn9rtEbnOnFfmkuGbY+SY/ahdcITFV0BmRziSh1MulffNMx/fLmxcnZ3DVk7Qj3BVJaXkgPNaCEPhqlENP2J+H18uFzFmgx8lJjoRlweWvir6cGXrhcZdhJ3OQvm4gjgfVvQ7H8olT/yz1VnC7g/z3qLlb9rtiWxOJyIyzU4V9sZc+fvIV7746imG8o5QlxWmZUSUZIxpY4wJIqIriejLOnye8vdE/UipK+pDSn2gfqR45bhXmCzLqjTG3EZE3xNRABG9Y1mWd5GPogDqR0pdUR9S6gP1I8UXdQnJkWVZs4hoVu3f4aKqKmc5sLjYWa4uLW0vXhkUJJd6MXwUG/u53W7YcLHoCww8KGweOtu79w7RFxoqF9lyc2X4y5g81pZL9Sg2Dw11lvvCw2XopaIiWtj798sl58jI+XbbDQJO3PFkYLkwKsrZCdeo0XT4Xhla4mMTFfWd6GvVSi7H5udnC5vvkisrc0I4PBR5PByLH5mAIAqOdFbNu3VzxHknn/yZeO0338jV9H/+0/Pu4Jsbdxb2hRdK4f3ZZztLxv/+dwfRl7l3ubAjy6U+JiPDCVM1hPDqrl0LhP3YY1fY7fxAucOxEkI4mzbNFHaHDhfZbVzSLwbx9WkghkxKGmy3zzhDjkVFhVw+zw88225jWPCaa4aS5GlhNWjgfM+TT44WffJsj41jnYt2BzWgO1o64fiRbNk/KmqieO1546TYftIkueyfxsJueO926zZK2Keyuer771eIvgTwhXgIY/CQHIqCU+C9VW//6BhwTGkhkcJOAkHuliYnU21Bkf8Bdh9deLb83uXL5RzujnDm04SwONF3882PCLtrV7nQc9VVjv9esvMnu50Hot5j5Vj9KDswRIibKb6H3UxIeFK8NmT0v4X99v+mCZuHzlJAdsHvbSKiBBZ6euGn/4i+h/aCX4EInIv+V4bL34exabOFXbLjR/LEJhZ+JSK6OXujsMezzQO4iSYdNgvgpoWJcR2dYxr4f6Jv1SqY55lUoSmE+sbeKOefCy7YKeyhQyfb7dlsY8zeQvn8cQTN9K0oiqIoiuIDfWBSFEVRFEXxgT4wKYqiKIqi+OC4S6McD23axFuPPHKNbXNdUmI7mQrAZWTMc/duuaW6uNiJs7tcUtuRk3OxsHmSSJ7WgKhmhnGuzSHCjOLytUUsUSURkcUydJeXy5hsbKyMBWPCzOLi1na7ArQLmIwSNU1xcU4agubNXxB9eD579txtt9u0uUf0YRb0goI+HvtDQpxtvi+8cDulp2853uy6x4Q7It6K6natbfdt6lyDqKjW4rVvvSVzzn39teyfPt3ZJltQIJOdNmzYUtiXXPKK3c7Kktully6NFfZe0DSdfvqd7L2rRV8aaAbimB4iJ2eL6G
sC2pIiSCIYzbZtYxqEjZCAsA+8l+sAToXt3wG7lwibb2lfAdqv7eC7JZCllyeujANN1qLlE1dYliX3AZ8gmodEWTe
"text/plain": [
"<Figure size 720x720 with 16 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"figure(figsize=[10,10])\n",
"for i in range(16):\n",
" subplot(4,4,i+1)\n",
" imshow(cic_paint(jnp.zeros(mesh_shape), res[0][::2][i]).sum(axis=0), cmap='gist_stern',\n",
" vmax=cic_paint(jnp.zeros(mesh_shape), poss[::2][i]).sum(axis=0).max(),vmin=0)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/local/home/flanusse/.local/lib/python3.8/site-packages/jax/_src/lax/lax.py:488: ComplexWarning: Casting complex values to real discards the imaginary part\n",
" return _convert_element_type(operand, new_dtype, weak_type=False)\n"
]
}
],
"source": [
"k, pk_ref = power_spectrum(\n",
" compensate_cic(cic_paint(jnp.zeros(mesh_shape), poss[-1])),\n",
" boxsize=np.array([25.] * 3),\n",
" kmin=np.pi / 25.,\n",
" dk=2 * np.pi / 25.)\n",
"\n",
"k, pk_c = power_spectrum(\n",
" compensate_cic(cic_paint(jnp.zeros(mesh_shape), res[0][-1])),\n",
" boxsize=np.array([25.] * 3),\n",
" kmin=np.pi / 25.,\n",
" dk=2 * np.pi / 25.)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, '$P(k)$')"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEQCAYAAABbfbiFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAABAsklEQVR4nO3dd1xV9R/H8de5bBRRUFyoqLhZKu5Z7nI03JYjt5VlpY1fWVlWlmmmmVppZeXI3HukuRNUVHArqDhQUBnKuvd+f39cIBVUxoXL+Dwfj/sQ7j33nM814813nO9XU0ohhBBC3Etn6QKEEELkPxIOQggh0pFwEEIIkY6EgxBCiHQkHIQQQqQj4SCEECIda0sXYC6lS5dWHh4eli5DCCEKlIMHD0Yqpco8+HyhCQcPDw8CAwMtXYYQQhQomqZdyOh56VYSQgiRjoSDEEKIdCQchBBCpFNoxhyEKIiSk5MJDw8nISHB0qWIQs7e3h53d3dsbGwydbyEgxAWFB4ejpOTEx4eHmiaZulyRCGllCIqKorw8HCqVq2aqfcU+G4lTdO6aZo2Lzo62tKlCJFlCQkJuLq6SjCIXKVpGq6urllqoRb4cFBKrVFKjXB2ds7W+8Mi75CoN5i5KiEyT4JB5IWs/jsr8OGQE0ajYu8PrzHn8zdZvPcUSXqjpUsSIs9pmsabb76Z9v3UqVP56KOPMjy2ePHi2b5OWFgYXl5e2X6/yFtFOhw0jHR2vc5rhgW03dSJGV+8w5//niPZICEhig47OzuWL19OZGSkpUsR+UjRDgedFS4jVqMGrcGhTDXG6+fSdF0npk6ZyF8BYeglJEQRYG1tzYgRI5g+fXqmjh83bhz16tWjXbt23LhxA4CgoCCaNm2Kj48Pzz77LLdu3QLg4MGD+Pr64uvry3fffZd2jtatWxMUFJT2fcuWLTly5Ij5PpTIMZmtBGhVW+P88jbU2W04r/+Qd2/N5NyaZXy2tT8+nYbQzc8dK530C4vc9fGaEI5fiTHrOetWKMGH3eo99riXX34ZHx8fJkyY8Mjj7ty5g7+/P9OnT2fSpEl8/PHHzJo1i4EDBzJz5kzatGnDxIkT+fjjj/nmm28YMmQIs2bNonXr1owfPz7tPEOHDuXnn3/mm2++4fTp0yQkJODr65vjzyvMp0i3HO6jaWg12lNi7G5Un99xK+XExMSvqb2yCx9/OYU1QZcxGmW/bVE4lShRgoEDB/Ltt98+8jidTkefPn0AeOGFF9i9ezfR0dHcvn2bNm3aADBo0CB27tzJ7du3uX37Nq1btwbgxRdfTDtPr169WLt2LcnJycyfP5/BgwfnzgcT2SYthwdpGlqdrjjVegpjyArcN01iUtznHFm+hA82D6Rl5z508iqPTloSwswy8xt+bnr99ddp0KABQ4YMAcBgMNCwYUMAunfvzqRJk9K9J7szrRwdHenQoQOrVq1i6dKlHDx4MPuFi1whLYeH0enQeT9P8XEHMXb/jhrFE5l89yNKL+vBO9NmsznkGkpJS0IUHi4uLvTu3ZuffvoJACsrK4KCgggKCkoLBqPRyLJlywD4448/aNmyJc7OzpQqVYpdu3YBsHDhQtq0aUPJkiUpWbIku3fvBuD333+/73rDhg1j7NixNGrUiFKlSuXVxxSZJOHwOFbW6Bq8gOMbQRifmoaX422+jHsPx8XPMX76T/x9MkJCQhQab7755iNnLRUrVowDBw7g5eXF33//zcSJEwH45ZdfGD9+PD4+PgQFBaU9v2DBAl5++WX8/PzS/X/SsGFDSpQokdZSEfmLVlh+sPn7+6s82c8hOR5DwHySd0zFPukm2wz1WeP6Es8+1YXWNUrLDU0iS06cOEGdOnUsXYZFXLlyhbZt23Ly5El0Ovk9NS9k9O9N07SDSin/B4+V/yJZZeOAVfOXsX/zGIYnJ9LS/hzf3H6VuIX9eX3WYvacjZSWhBCP8euvv9KkSRMmT54swZBPSc
shpxKiMeyZhWHvd1gb7rLS0ILt5V5iQJe2NK3mmvf1iAKlKLccRN6TlkNesnfGqt3/sH3jGMZmr9LdJpDpN4YROn8or36/msCwm5auUAghskzCwVyKuWLd6ROsxx2FRsPobbObryOGEPzjSF6Zt55DF29ZukIhhMg0CQdzcyqL9dNfYfV6EDq/AQy0+ZuvrgziwLxXePXHLRwNv23pCoUQ4rEkHHKLszvWz3yL7tUArL2eYaT1Or4If4Htc17n1fk7CLki+08IIfIvCYfc5lINm54/oI3Zj23tTrxmvYLJF/uz7ru3eO2X3Zy8Zt61dITIjuwsxR0WFoaDgwN+fn7UrVuXUaNGYTQaCQsLQ9M03n///bRjIyMjsbGx4ZVXXkl3np9//pkyZcqkneeHH37I0Wd5GCsrK/z8/PDy8qJXr17cvXsXMN3l/cILL6Qdp9frKVOmDF27ds2VOgoKCYe84lYbm76/wshdOHi2ZILNUiae78efM9/h9d/2cSYi1tIVCpFl1atXJygoiKNHj3L8+HFWrlwJQNWqVVm3bl3acX/++Sf16j18eZA+ffoQFBTEjh07eO+994iIiMjU9fV6faZrdXBwICgoiODgYGxtbZkzZw5gurEvODiY+Ph4ALZs2ULFihUzfd7CqsCHQ4HbJrS8DzYvLIVh2yhRtQEf2PzOu2f68cu3H/DGogOcuxFn6QpFERUXF0e7du1o0KAB3t7erFq1CoCAgAB8fHxISEjgzp071KtXj+Dg4Pvea21tTfPmzTl79ixgWjupTp06pE4vX7JkCb17935sDW5ublSvXp0LFy4wePDgtKU64L/WzY4dO2jVqhXdu3enbt26GAwGxo8fT6NGjfDx8WHu3LmPvU6rVq3SagV46qmn0sJs0aJF9OvX77HnKOwK/MJ7Sqk1wBp/f//hlq4lS9z9sRm8CsJ247JlEp9eXsDlk2uZEfwsRp8+vNq+DlVci1m6SpGXNrwD146Z95zlvKHLF5k61N7enhUrVlCiRAkiIyNp2rQp3bt3p1GjRnTv3p3333+f+Ph4XnjhBby8vAgLC0t77927d9m2bdt9i/P17duXxYsXU7ZsWaysrKhQoQJXrlx5ZA3nz5/n/PnzeHp6PvK4Q4cOERwcTNWqVZk3bx7Ozs4EBASQmJhIixYt6NixI1WrVs3wvXq9ng0bNtC5c+f7ap00aRJdu3bl6NGjvPTSS2lrRRVVBT4cCjyPltgM2wTntuG25RO+jJhH6PHVTD/WEzvfnrzSrhaVXBwtXaUoApRSvPfee+zcuROdTsfly5eJiIigXLlyTJw4kUaNGmFvb3/fst7nzp3Dz88PTdPo0aMHXbp0SQuNzp0788EHH1C2bNm0Zb4fZsmSJezevRs7Ozvmzp2Li4vLI49v3Lhx2g//zZs3c/To0bRWRnR0NGfOnEkXDvHx8fj5+QGmlsPQoUPTXvPx8SEsLIxFixbx1FNPZervq7CTcMgPNA0822NTvR2cWk+lrZ/wTeQsTgWv4rOgXpRs8CyvtKtBxZIOlq5U5KZM/oafW37//Xdu3LjBwYMHsbGxwcPDg4SEBACioqKIi4sjOTmZhIQEihUztWpTxxwyYmtrS8OGDfn66685fvw4q1evfui1+/Tpw6xZs+57ztraGqPRtBuj0WgkKSkp7bXU64Mp1GbOnEmnTp0e+flSxxwepnv37rz11lvs2LGDqKioR56rKCjwYw6FiqZB7aexHrMXes6nuost39tMo/+RQXwwdTofrDjGtegES1cpCqno6Gjc3NywsbFh+/btXLhwIe21kSNH8sknnzBgwADefvvtTJ/zzTffZMqUKY9tCWTEw8MjbZ+H1atXk5ycnOFxnTp14vvvv097/fTp09y5cyfL13vppZf48MMP8fb2zvJ7CyNpOeRHOh14PY91nR5wdAl1tn/O/JgpBB5eyVuHeuPZqAtj2lbHrYS9pSsVhYBer8fOzo4BAwbQrVs3vL298ff3p3
bt2oBpkTwbGxv69++PwWCgefPm/P3331SrVu2x565Xr94jZyk9yvDhw+nRowe+vr507tz5vtbCvYYNG0ZYWBgNGjR
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"loglog(k, pk_ref, label='N-body')\n",
"loglog(k, pk_i, label='JaxPM Pure PM')\n",
"loglog(k, pk_c, '--', label='JaxPM Hamiltonian GNN')\n",
"legend()\n",
"plt.xlabel(r\"$k$ [$h \\ \\mathrm{Mpc}^{-1}$]\")\n",
"plt.ylabel(r\"$P(k)$\")"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAAEACAYAAABF+UbAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAA6r0lEQVR4nO3de3wddZ34/9d7zsk9aZM2bXpLb9ByKXfaQkWl3LRc1qrgCroq/hDwh7rrZV0uuq7rV2ErXnZh0RUvi7Iqy+IqfKFQQUhFRCiXQmkppRTatE0paZO0uZ7LvL9/zElJc52TzMmcSd7Px2MeJDNzZt7tlHc++czn8/6IqmKMMSaanLADMMYYM3yWxI0xJsIsiRtjTIRZEjfGmAizJG6MMRFmSdwYYyLMkrgxxmRJRH4mIntF5KUBjouI3CIiW0XkRRE5JVexWBI3xpjs3QGsGOT4+cCCzHYV8MNcBWJJ3BhjsqSqfwT2D3LKSuAX6vkLUCki03MRiyVxY4wJ3kygvsf3OzP7AhfPxUX9qK6u1rlz54Z1e+NDW1sbZWVlYYdhciwKz/nZZ59tVNUpI7nGe88q03370/7u92LXRqCzx67bVfX2kdw/V0JL4nPnzuWZZ54J6/bGh7q6OpYvXx52GCbHovCcRWT7SK/RuD/NU2tm+Tq3YPprnaq6eAS32wXU9vh+VmZf4Kw7xRgzTihpdX1tAbgP+HhmlMrpQIuqNgRx4d5Ca4kbY8xoUiCFv+6UoYjIr4HlQLWI7AT+CSgAUNX/AFYDFwBbgXbgk4HcuB+WxI0x44KipAMqva2qlw1xXIHPBHKzIVgSN8aMGy5jb/0ES+LGmHFBgbQlcWOMiS5riRsTIeu2PM+6xq9ysPk4rr1gVdjhmJApBNYnnk8siZsx58aHruDq4//IyRVw6gTB1S08s+0hygv+h6NrF4YdngmJoiTHYEvcxombMeW7D36TL5ywlgkO7E45PNxcxQEXTintYGbsAn786E/CDtGERSHtc4sSS+JmzLj36fu58sQ7KBDhjh1HMLv2Fd579FM4pU/zaMtESkV474Kbww7ThEQB1+cWJZbEzZixdMYXKRPhof3VXHnaQ4f2V1ZVce5R69iejDGrQLnz2XNCjNKER0j73KLEkrgZE2574q+oicNriTgXLXqy33Oe2/UtulyXD83Yzq8fv2uUIzRhU8BVf1uUWBI3Y8Jlc18mpcra17404DmXLLuYu3bNpVAcTpn9z6MYnckHCiRwfG1REq1ojenHz585l8qYw6bOYq48+1ODnvvJJY+wL63MK0xx79P3j1KEJl+4Kr62KLEkbiKtuamJD8x4g6S6bH/rJl+fWb1nNnERKiv/KcfRmXzizdi0PnFj8sr9r19MuePwXFs5K5de5OszK+ffQ6frcmp5C81NTTmO0OQLRUjj+NqiJFrRGtPLe6bWk1Klpenbvj9TWVXFix3llDoOd23+SA6jM/nGulOMySOrHrieyTHYmYyxYsl7svrslvovkVbl/bWv5Sg6k2+sO8WYPHP+sf8XEeF/tpyc9Wc/vvxj7ErGmBJTbl3z/RxEZ/KNIiQ17muLEkviJpLq9zSwsLiTg67Ltef+eljXeGjXfESERfN+GXB0Jl9ZS9yYPPHHho9QKA5PH6wa9jUunPdTkupyfFlLgJGZfKUqpNXxtUVJtKI1JuPcKbtIqZI+8K1hX6N22nTeTMWocuCeJ38TYHQmX7mIry1KLImbyPlF3Z1Mjil7Uk7WLzR7e7xxBo4IhRO+E1B0Jl95LzZtiKExoZs27VYcER7ZM3vE11pY9j3SqiyZuC+AyEx+s+4UY/LCKRVNJNXlvNpfjPhaSxaeTGNaqI4p67Y8H0B0Jl8pkNSYry1KLImbSLltzS1UOrArGad22vRArvlMy2RiImxp/WIg1zP5yWZsGpMHjpv3cxwRHtp5RGDX7Drw97iqLKtuCOyaJj+56vjaoiRa0Zpx7/iyFh
Lq8pFj/iuwa16y7GIOukpNPBXYNU3+sRebxoTsuw9+k4mOUJ8ooLJq+OPD+7MrWUyJ43Dzg18P9LomfyhCWv1tUWJJ3ETG4iN+i4jw0I4FgV/7Tw3eSJfj5z4Y+LVNflDFpt0bE6ZFZQdIqMvHjvt54Nc+o+Y7pFVZVN4c+LVNvvA30WdMTvYRkRUi8oqIbBWR6/o5PltEHhOR50XkRRG5IPhQzXh2z5O/odKBN1OxwLtSAI6ffywtLlTH0oFf2+QHhfE5TlxEYsBtwPnAscBlInJsr9O+CtytqicDlwI/CDpQM74VVHwXR4Q/Nc7I2T22d5VQ5Djc+MA1ObuHCdd4fbG5FNiqqttUNQHcBazsdY4CEzJfTwR2BxeiMbC4spG0KgvLvpeze9TtWAjAGQueytk9THgUfwtCjMVFIWYC9T2+35nZ19PXgb8RkZ3AauBzgURnDLC5fgtTYsq+tLBkYfa1w/3666NvJaXKwtKDObuHCddYbIkH9Rr2MuAOVf2uiCwD7hSR41TV7XmSiFwFXAVQU1NDXV1dQLc3udDa2poXz+i10m+xcLaw/mAVxTmOp/BIYVLM5d4HfsvEsuD73vNRvjznXPMWhYjWlHo//CTxXUBtj+9nZfb1dAWwAkBVnxSRYqAa2NvzJFW9HbgdYPHixbp8+fLhRW1GRV1dHfnwjGbv+DSuKvUNV3Ll2bmN56nXSlhS1sFLsdV8ZfmPc3qvfDHUc35u2wlMiCU5cs7LoxdUDihEbjamH37+ROuABSIyT0QK8V5c3tfrnB3AOQAicgxQDLwVZKBm/JoeT3HQVa48+1M5v9efM9P5T5u3Puf3iooFxe3MKkyGHUYgxuXKPqqaAj4LrAFexhuFslFEviEi78uc9iXgShF5Afg1cLmqaq6CNuPHTWs+QZHjsLWrbFTud/78f/FeoJYeGJX7RUGxCLGIJbb+qEpgtVPyadi1r98tVHW1qi5U1SNU9VuZfV9T1fsyX29S1TNU9URVPUlVf5+rgM34ct6C9QA8vOWUUbnf0bULaXWVyTZeHIB/e+hmYiLERMZEqd4gxonn27DrsddBZMaUI4ra6HRdbljxn6N2z4ZkEcUi/PjRn4zaPfNVwYS6Q18/+urd4QUSACWw5dnyati1JXGTt35RdycVjtCQGt1aFi80TUFEqJh056jeNx8tnPz22AQp2RZiJCOnCEk35msDqkXkmR7bVT0ulVfDri2Jm7xVOeUHOCL8+a3e/3/kVlfz5agqJ1bZu/kZRe2Hvq6u2B9iJMHIYpx4o6ou7rHdnuWtuoddzwIuwBt2nZN8a0nc5K1TJuwjrcrRFaO7iPHlZ11OhyrT4olRvW8+qoq/PSplSmlriJGMXIAzNv0Ou74bvGHXeCP2qgP6oxzGkrjJS81NTUyJuzS55HSW5kD2p2OUO8KGbZtG/d75pMxx6XK9OXtVhV0hRzNyLo6vbQh5NezakrjJSz/f8EkKxGFz24ShT86BV9onEBPh0Z03hHL/fFEsQrPrpYkJEf/NRJVAFoXIt2HX0ap+bsaNs2u3AvDk1nN598LRv/+6107lnFMe5fSZr4/+zfPErWu+z2dOFN5MFjEl1klpxIddKkLKDWbavaquxnth2XPf13p8vQk4I5CbDcFa4iYvzS7qpN11ufaCVaHc/5plN5FSZXZRRyj3zwdO+aMA7GgvJ41S4kQ7icM4nbFpzGi7bc0tlIuwO1kQWgyVVVUccGGi4w598hh1ZGZ44Za9s0gpFEm0J2F7tVPGZylaY0bVvNpfISL8ae+sUOPYkyyk2HG4dc33Q40jLLOKvdEoC8o+QVKhMOJJHIKbdp9PohWtGRdOrNhPWpUlk/8t1DheaJoKQPXUe0ONIyyT4kmS6rJy6UUkkDHxAm3crrFpzGhpbmqiOjO08Pj5vctRjK7Wxg8BcMI4nfRT5rh0ZHqTulyHuEQrufUW1OiUfGNJ3OSVn73wKW9oYXtF2KFw9XnX0Om6TCuI9t
C64SoWoTUzvLDddYgB9Xsawg1qBLpHp/jZosSSuMkr5855BYAntpwVciSeZtehwvF+QxhPfvzoT4iLsC9VCEBbugA
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"col = cm.viridis(np.linspace(0.,1.0,len(scales[::2]))) \n",
"\n",
"kvals = jnp.logspace(-2.,0,100)*sqrt(3)\n",
"\n",
"for i, a in enumerate(scales[::2]): \n",
" semilogx(kvals, model.apply(params, kvals , jnp.atleast_1d(1)), color=col[i])\n",
"\n",
"sm = plt.cm.ScalarMappable(cmap=\"viridis\", norm=plt.Normalize(vmin=0., vmax=1))\n",
"plt.colorbar(sm, label='a')\n",
"\n",
"#xlim(kvals[0], 1)\n",
"#ylim(0,5)\n",
"grid()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}