JaxPM/notebooks/ParallelNbody-LPT.ipynb

428 lines
1.4 MiB
Text
Raw Normal View History

2022-10-19 17:37:38 -07:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "40ba1e08-69c0-494d-9f6f-879849813bf8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%pylab inline\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.lax as lax\n",
"from jax.experimental.maps import xmap\n",
"from jax.experimental.maps import Mesh\n",
"from jax.experimental.pjit import PartitionSpec, pjit\n",
"from jaxpm.pm import cic_paint, cic_read, fftk\n",
"from functools import partial\n",
"import jax_cosmo as jc"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "72822b71-95cf-452e-8e30-e3c2e77b8156",
"metadata": {},
"outputs": [],
"source": [
"# Number of mesh cells (and of particles) per side\n",
"nc=512\n",
"# Box side length\n",
"boxsize=1024. # Mpc/h\n",
"# Halo padding, in cells, added on each side of a shard's local block along the sharded axes\n",
"halo_size=16"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "69c5e0ec-e82b-4ee1-94e1-4da546d26155",
"metadata": {},
"outputs": [],
"source": [
"def cic_paint(mesh, positions):\n",
"    \"\"\"Paint particles onto a periodic mesh with cloud-in-cell (CIC) weights.\n",
"\n",
"    Each particle deposits unit mass onto the 8 corners of the cell that\n",
"    contains it, weighted by the trilinear CIC kernel. Neighbour indices wrap\n",
"    periodically along each axis using that axis's own length, so non-cubic\n",
"    meshes are handled correctly. This local definition shadows the\n",
"    cic_paint imported from jaxpm.pm in the first cell.\n",
"\n",
"    Args:\n",
"      mesh: [nx, ny, nz] array to accumulate into.\n",
"      positions: [npart, 3] particle positions in mesh-cell units.\n",
"\n",
"    Returns:\n",
"      A copy of `mesh` with the CIC weights of all particles added.\n",
"    \"\"\"\n",
"    positions = jnp.expand_dims(positions, 1)\n",
"    floor = jnp.floor(positions)\n",
"    # Offsets of the 8 corners of the cell containing each particle\n",
"    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],\n",
"                             [0., 0, 1], [1., 1, 0], [1., 0, 1],\n",
"                             [0., 1, 1], [1., 1, 1]]])\n",
"\n",
"    neighboor_coords = floor + connection\n",
"    kernel = 1. - jnp.abs(positions - neighboor_coords)\n",
"    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]\n",
"\n",
"    # Wrap each axis by its own length: using mesh.shape[-1] for every axis\n",
"    # is only correct for cubic meshes (the halo-padded blocks are not cubic).\n",
"    neighboor_coords = jnp.mod(neighboor_coords.reshape([-1, 8, 3]).astype('int32'),\n",
"                               jnp.array(mesh.shape))\n",
"\n",
"    dnums = jax.lax.ScatterDimensionNumbers(\n",
"        update_window_dims=(),\n",
"        inserted_window_dims=(0, 1, 2),\n",
"        scatter_dims_to_operand_dims=(0, 1, 2))\n",
"    mesh = lax.scatter_add(mesh,\n",
"                           neighboor_coords,\n",
"                           kernel.reshape([-1, 8]),\n",
"                           dnums)\n",
"    return mesh\n",
"\n",
"def cic_read(mesh, positions):\n",
"    \"\"\"Interpolate mesh values at particle positions with cloud-in-cell weights.\n",
"\n",
"    Each particle reads the 8 corners of the cell containing it, weighted by\n",
"    the trilinear CIC kernel. Neighbour indices wrap periodically along each\n",
"    axis using that axis's own length. This local definition shadows the\n",
"    cic_read imported from jaxpm.pm in the first cell.\n",
"\n",
"    Args:\n",
"      mesh: [nx, ny, nz] array to sample.\n",
"      positions: [npart, 3] particle positions in mesh-cell units.\n",
"\n",
"    Returns:\n",
"      [npart] array of interpolated values.\n",
"    \"\"\"\n",
"    positions = jnp.expand_dims(positions, 1)\n",
"    floor = jnp.floor(positions)\n",
"    # Offsets of the 8 corners of the cell containing each particle\n",
"    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],\n",
"                             [0., 0, 1], [1., 1, 0], [1., 0, 1],\n",
"                             [0., 1, 1], [1., 1, 1]]])\n",
"\n",
"    neighboor_coords = floor + connection\n",
"    kernel = 1. - jnp.abs(positions - neighboor_coords)\n",
"    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]\n",
"\n",
"    # Wrap each axis by its own length (mesh.shape[-1] alone is only correct\n",
"    # for cubic meshes).\n",
"    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape))\n",
"\n",
"    # Gather the 8 neighbours (the z index is component 2; there is no component 3)\n",
"    return (mesh[neighboor_coords[..., 0],\n",
"                 neighboor_coords[..., 1],\n",
"                 neighboor_coords[..., 2]] * kernel).sum(axis=-1)\n"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "56d06cdd-b1f3-434f-85e9-01ce7bb5e2e5",
"metadata": {},
"outputs": [],
"source": [
"# Defining the main operations\n",
"@partial(xmap,\n",
"         in_axes={0: 'x', 1: 'y'},\n",
"         out_axes=['x', 'y', ...],\n",
"         axis_sizes={'x': nc, 'y': nc},\n",
"         axis_resources={'x': 'nx', 'y': 'ny',\n",
"                         'key_x': 'nx', 'key_y': 'ny'})\n",
"def pnormal(key):\n",
"    \"\"\"Draw `nc` standard-normal samples from the PRNG key of one (x, y) column.\"\"\"\n",
"    return jax.random.normal(key, shape=(nc,))\n",
"\n",
"@partial(xmap,\n",
" in_axes={0:'x', 1:'y'},\n",
" out_axes=['x','y',...],\n",
" axis_resources={'x': 'nx', 'y': 'ny'})\n",
"def pfft3d(mesh):\n",
" \"\"\"Distributed 3D FFT of a field whose first two axes are named 'x' and 'y'.\n",
"\n",
" Each jnp.fft.fft call transforms the single local (unnamed) axis; the two\n",
" all_to_all calls swap a named axis with the local one so every direction is\n",
" transformed in turn. Following the inline layout comments, the output is\n",
" stored transposed as [z, x, y] (named 'x' holds kz, named 'y' holds kx, the\n",
" local axis holds ky); pifft3d is meant to undo this ordering.\n",
" \"\"\"\n",
" # [x, y, z]\n",
" # FFT along the local z axis\n",
" mesh = jnp.fft.fft(mesh)\n",
" mesh = lax.all_to_all(mesh, 'x', 0, 0) # [z, y, x]\n",
" # FFT along the now-local x axis\n",
" mesh = jnp.fft.fft(mesh)\n",
" mesh = lax.all_to_all(mesh, 'y', 0, 0) # [z, x, y]\n",
" # FFT along the now-local y axis\n",
" return jnp.fft.fft(mesh)\n",
"\n",
"@partial(xmap,\n",
" in_axes={0:'x', 1:'y'},\n",
" out_axes=['x','y',...],\n",
" axis_resources={'x': 'nx', 'y': 'ny'})\n",
"def pifft3d(mesh):\n",
" \"\"\"Inverse of pfft3d: distributed 3D inverse FFT returning the real part.\n",
"\n",
" Applies the same steps as pfft3d in reverse order (ifft, all_to_all on 'y',\n",
" ifft, all_to_all on 'x', ifft). Presumably expects the transposed [z, x, y]\n",
" layout produced by pfft3d and returns an [x, y, z] field -- confirm.\n",
" \"\"\"\n",
" # [z, x, y]: inverse FFT along the local axis\n",
" mesh = jnp.fft.ifft(mesh) \n",
" mesh = lax.all_to_all(mesh, 'y', 0, 0)\n",
" # [z, y, x]: inverse FFT along the local axis\n",
" mesh = jnp.fft.ifft(mesh)\n",
" mesh = lax.all_to_all(mesh, 'x', 0, 0) \n",
" # [x, y, z]: final inverse FFT; drop the (numerically small) imaginary part\n",
" return jnp.fft.ifft(mesh).real"
]
},
{
"cell_type": "code",
"execution_count": 153,
"id": "b5378206-5dac-431d-9691-02366ed3470b",
"metadata": {},
"outputs": [],
"source": [
"@partial(xmap,\n",
"         in_axes=(['x', 'y', ...],\n",
"                  ['x'],\n",
"                  ['y'],\n",
"                  [...], [...], [...]),\n",
"         out_axes=['x', 'y', ...],\n",
"         axis_resources={'x': 'nx', 'y': 'ny'})\n",
"def cwise_fn(kfield, kx, ky, kz, k, pk):\n",
"    \"\"\"Scale the Fourier field `kfield` by sqrt(P(|k|) * nc^3 / boxsize^3).\n",
"\n",
"    P is obtained by interpolating the tabulated spectrum (`k`, `pk`) at the\n",
"    wavenumber magnitude of each mesh mode.\n",
"    \"\"\"\n",
"    # Rescale mesh frequencies by nc / boxsize (assumed to yield physical\n",
"    # wavenumbers matching `k` -- confirm fftk units)\n",
"    qx = kx / boxsize * nc\n",
"    qy = ky / boxsize * nc\n",
"    qz = kz / boxsize * nc\n",
"    kmag = jnp.sqrt(qx**2 + qy**2 + qz**2)\n",
"    pk_on_mesh = jc.scipy.interpolate.interp(kmag, k, pk)\n",
"    return kfield*(pk_on_mesh*nc**3/boxsize**3)**0.5\n",
"\n",
"def get_initial_cond(cosmo, seed):\n",
"    \"\"\"Draw a Gaussian random field with the linear matter power spectrum of `cosmo`.\n",
"\n",
"    Returns the real-space [nc, nc, nc] field produced by pifft3d.\n",
"    \"\"\"\n",
"    # One PRNG key per (x, y) column; pnormal draws nc normals per key\n",
"    column_keys = jax.random.split(seed, nc*nc).reshape(nc, nc, -1)\n",
"    white_noise_k = pfft3d(pnormal(column_keys))\n",
"\n",
"    # Tabulated linear power spectrum, interpolated onto the mesh by cwise_fn\n",
"    k = jnp.logspace(-4, 2, 256)\n",
"    pk = jc.power.linear_matter_power(cosmo, k)\n",
"    kvec = fftk([nc, nc, nc], symmetric=False)\n",
"\n",
"    colored_k = cwise_fn(white_noise_k,\n",
"                         kvec[0].squeeze(), kvec[1].squeeze(), kvec[2].squeeze(),\n",
"                         k, pk)\n",
"    return pifft3d(colored_k)"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "24b1ee98-e3fa-471c-a44d-5aa6cbbaa49f",
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so the initial conditions are reproducible\n",
"key = jax.random.PRNGKey(42)\n",
"\n",
"# Arrange the 4 local devices into the 2x2 ('nx', 'ny') mesh used below\n",
"devices = np.array(jax.local_devices()).reshape((2, 2))\n",
"\n",
"cosmo = jc.Planck15()"
]
},
{
"cell_type": "markdown",
"id": "f01f4d7b-39da-4479-8eb8-1ae9ab975ec4",
"metadata": {},
"source": [
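"Now let's implement first-order Lagrangian Perturbation Theory (LPT): particles start on a uniform grid and are displaced by the gradient of the inverse Laplacian of the initial density field."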
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "5f50c23d-b269-493e-af6b-09388419ef4d",
"metadata": {},
"outputs": [],
"source": [
"@partial(xmap,\n",
"         in_axes=([...]),\n",
"         out_axes={0: 'sx', 2: 'sy'},\n",
"         axis_sizes={'sx': 2, 'sy': 2},\n",
"         axis_resources={'sx': 'nx', 'sy': 'ny'})\n",
"def pmeshgrid(x, y, z):\n",
"    \"\"\"Build the grid of local particle coordinates on every (sx, sy) shard.\n",
"\n",
"    The body does not depend on the shard, so each shard gets the same\n",
"    [len(y), len(x), len(z), 3] block of (x, y, z) coordinates.\n",
"    \"\"\"\n",
"    # jnp.meshgrid's default 'xy' indexing swaps the first two output axes\n",
"    coords = jnp.meshgrid(x, y, z, indexing='xy')\n",
"    return jnp.stack(coords, axis=-1)\n",
"\n",
"@partial(xmap,\n",
" in_axes=(['x','y','z'],\n",
" ['x'], ['y'], ['z']),\n",
" out_axes=(['x','y','z'],\n",
" ['x','y','z'],\n",
" ['x','y','z']),\n",
" axis_resources={'x': 'nx', 'y': 'ny'})\n",
"def papply_gradient_laplace(kfield, kx, ky, kz):\n",
" \"\"\"Return three Fourier-space force components derived from `kfield`.\n",
"\n",
" Multiplies by the inverse Laplacian 1/k^2 and by the fourth-order\n",
" finite-difference gradient kernel i*(8 sin k - sin 2k)/6 along one axis.\n",
" At k == 0 the 1/k^2 factor is replaced by 1; the gradient factor is zero\n",
" there anyway.\n",
"\n",
" NOTE(review): the outputs use (ky, kz, kx) in terms of the argument names.\n",
" This looks intentional: pfft3d stores its output transposed as [z, x, y],\n",
" so the 'x', 'y', 'z' axes here carry physical kz, kx, ky. The returned\n",
" tuple is therefore in physical (x, y, z) order -- TODO confirm.\n",
" \"\"\"\n",
" kk = (kx**2 + ky**2 + kz**2)\n",
" kernel = jnp.where(kk == 0, 1., 1./kk)\n",
" return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), \n",
" kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), \n",
" kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))) \n",
"\n",
"# Split each sharded nc-long axis into (2, nc//2) so that the size-2 axes line\n",
"# up with the 2x2 device mesh: [nc, nc, ...] -> [2, nc//2, 2, nc//2, ...].\n",
"# The result can then be consumed by xmaps that name axes 0 and 2 ('sx', 'sy').\n",
"preshape = pjit(lambda x: x.reshape([2, nc//2, 2, nc//2]+list(x.shape[2:])), \n",
" in_axis_resources=PartitionSpec('nx','ny'),\n",
" out_axis_resources=PartitionSpec('nx', None, 'ny', None))\n",
"\n",
"# Inverse of preshape: merge [2, nc//2, 2, nc//2, ...] back into [nc, nc, ...]\n",
"pireshape = pjit(lambda x: x.reshape([nc, nc]+list(x.shape[4:])), \n",
" in_axis_resources=PartitionSpec('nx', None, 'ny', None),\n",
" out_axis_resources=PartitionSpec('nx','ny'))\n",
"\n",
"# Per-shard CIC read: each (sx, sy) shard interpolates its own local mesh block\n",
"# at its local particle positions and keeps the per-particle layout of `pos`.\n",
"# NOTE(review): no halo is read, so particles near a shard edge do not see the\n",
"# neighbouring shard's cells -- TODO consider halo exchange here as well.\n",
"pcic_read = xmap(lambda mesh, pos: cic_read(mesh, pos.reshape(-1,3)).reshape(pos.shape[:-1]),\n",
" in_axes=({0:'sx',2:'sy'},\n",
" {0:'sx',2:'sy'}),\n",
" out_axes=({0:'sx',2:'sy'}),\n",
" axis_resources={'sx': 'nx', 'sy': 'ny'})\n",
"\n",
"# Per-shard CIC paint onto a halo-padded local block: positions are shifted by\n",
"# halo_size along the two sharded axes (not along z), so particles that drift up\n",
"# to halo_size cells outside their shard still land inside the padding.\n",
"pcic_paint = xmap(lambda mesh, pos, halo_size=halo_size: cic_paint(mesh, pos.reshape(-1,3)+jnp.array([halo_size,halo_size,0]).reshape([-1,3])),\n",
" in_axes=({0:'sx',2:'sy'},\n",
" {0:'sx',2:'sy'}),\n",
" out_axes=({0:'sx',2:'sy'}),\n",
" axis_resources={'sx': 'nx', 'sy': 'ny'})\n",
"\n",
"@partial(xmap,\n",
"         in_axes=({0:'sx',2:'sy'},),\n",
"         out_axes={0:'sx',2:'sy'},\n",
"         axis_resources={'sx': 'nx', 'sy': 'ny'})\n",
"def pad_mesh(mesh, halo_size=halo_size):\n",
"    \"\"\"Zero-pad the local block by `halo_size` cells on both sharded axes.\n",
"\n",
"    Only axes 0 ('sx') and 1 ('sy') are padded; the unsharded last axis is left\n",
"    unchanged, matching the halo layout that halo_reduce crops away.\n",
"    \"\"\"\n",
"    # halo_size must stay a static Python int (jnp.pad needs concrete widths),\n",
"    # so it is deliberately not listed in in_axes as a traced xmap input.\n",
"    return jnp.pad(mesh, [(halo_size, halo_size), (halo_size, halo_size), (0, 0)])\n",
"\n",
"@partial(xmap,\n",
"         in_axes=({0:'sx',2:'sy'}),\n",
"         out_axes={0:'sx',2:'sy'},\n",
"         axis_resources={'sx': 'nx', 'sy': 'ny'})\n",
"def halo_reduce(mesh, halo_size=halo_size):\n",
"    \"\"\"Fold halo contributions back onto the neighbouring shard, then crop.\n",
"\n",
"    Expects the local block of a painted field padded by `halo_size` cells on\n",
"    each side of axes 0 ('sx') and 1 ('sy'). For each sharded axis, the outer\n",
"    2*halo_size slabs are exchanged with the other shard and added onto the\n",
"    matching slabs; finally the padding is cropped off. The exchange uses\n",
"    perm=[1, 0], so it assumes exactly two shards per sharded axis.\n",
"    \"\"\"\n",
"    for axis_ind, axis_name in enumerate(['sx', 'sy']):\n",
"        size = mesh.shape[axis_ind]\n",
"        # Split into [left margin | center | right margin]. JAX arrays have no\n",
"        # `.split` method, so use jnp.split; the right split point comes from\n",
"        # the actual block size rather than assuming nc//2.\n",
"        left_margin, _, right_margin = jnp.split(\n",
"            mesh, [2*halo_size, size - 2*halo_size], axis=axis_ind)\n",
"\n",
"        # Perform halo exchange with the other shard along this axis\n",
"        left = lax.pshuffle(right_margin, perm=[1, 0], axis_name=axis_name)\n",
"        right = lax.pshuffle(left_margin, perm=[1, 0], axis_name=axis_name)\n",
"        # Explicit start indices (not -2*halo_size) stay correct for halo_size == 0\n",
"        if axis_ind == 0:\n",
"            mesh = mesh.at[:2*halo_size].add(left)\n",
"            mesh = mesh.at[size - 2*halo_size:].add(right)\n",
"        else:\n",
"            mesh = mesh.at[:, :2*halo_size].add(left)\n",
"            mesh = mesh.at[:, size - 2*halo_size:].add(right)\n",
"\n",
"    # removing leftovers: crop the halo padding on both sharded axes\n",
"    return mesh[halo_size:mesh.shape[0] - halo_size,\n",
"                halo_size:mesh.shape[1] - halo_size]"
]
},
{
"cell_type": "code",
"execution_count": 171,
"id": "65edea80-9ea1-4300-842a-63ae81a5dddb",
"metadata": {},
"outputs": [],
"source": [
"with Mesh(devices, ('nx', 'ny')):\n",
"    # Gaussian initial density field on the (nx, ny) device mesh\n",
"    initial_conditions = get_initial_cond(cosmo, key)\n",
"\n",
"    # Create the particles: one per cell, in local (per-shard) grid coordinates\n",
"    pos = pmeshgrid(jnp.arange(nc//2), jnp.arange(nc//2), jnp.arange(nc))\n",
"\n",
"    # Take the FFT of the field\n",
"    lineark = pfft3d(initial_conditions)\n",
"\n",
"    # Apply the inverse-Laplacian and finite-difference gradient kernels\n",
"    kvec = fftk([nc,nc,nc], symmetric=False)\n",
"    kforces = papply_gradient_laplace(lineark,\n",
"                                      kvec[0].squeeze(),\n",
"                                      kvec[1].squeeze(),\n",
"                                      kvec[2].squeeze())\n",
"\n",
"    # Inverse Fourier transform each component and read it at the particle positions\n",
"    forces_x = pcic_read(preshape(pifft3d(kforces[0])), pos)\n",
"    forces_y = pcic_read(preshape(pifft3d(kforces[1])), pos)\n",
"    forces_z = pcic_read(preshape(pifft3d(kforces[2])), pos)\n",
"\n",
"    # Stack the three components into one displacement vector per particle\n",
"    dx = xmap(lambda a,b,c: jnp.array([a,b,c]),\n",
"              in_axes=(['sx','tx','sy','ty','z',...],\n",
"                       ['sx','tx','sy','ty','z',...],\n",
"                       ['sx','tx','sy','ty','z',...]),\n",
"              out_axes=['sx','tx','sy','ty','z',...],\n",
"              axis_resources={'sx': 'nx', 'sy': 'ny'}\n",
"              )(forces_x, forces_y, forces_z)\n",
"\n",
"    # Displace the particles.\n",
"    # NOTE(review): dx is not scaled by a growth factor -- TODO confirm the intended epoch.\n",
"    x_final = xmap(lambda a,b:a+b,\n",
"                   in_axes=(['sx','tx','sy','ty','z',...],\n",
"                            ['sx','tx','sy','ty','z',...]),\n",
"                   out_axes=['sx','tx','sy','ty','z',...],\n",
"                   axis_resources={'sx': 'nx', 'sy': 'ny'}\n",
"                   )(dx,pos)\n",
"\n",
"    # Painting final field onto local blocks padded by halo_size on the sharded axes.\n",
"    # The leading 2s are the per-axis shard counts of the 2x2 device mesh.\n",
"    res = pcic_paint(jnp.zeros([2, nc//2 + 2*halo_size, 2, nc//2 + 2*halo_size, nc]), x_final)\n",
"    res = halo_reduce(res)\n",
"    res = pireshape(res).block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": 160,
"id": "acc4aaa7-9d45-4cbd-b560-86014de49922",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x15343c2cceb0>"
]
},
"execution_count": 160,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAJCCAYAAAA7hTjJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9yY5lSZYtiK0tzWluo6qmZuZN5EvmI6oIEvwA8gsIcMYpyQ8ocMAROeGAX1IDjvkFBfAbChyywCo+PryXGRnemJk2tzuNdBysLXLUAxkBFvAC5QnoATwz3Fzt6rnniGzZe+211pZSCt6v9+v9er/er/fr/Xq/3q/tMv9D38D79X69X+/X+/V+vV/v1+/tek+Q3q/36/16v96v9+v9er/+7HpPkN6v9+v9er/er/fr/Xq//ux6T5Der/fr/Xq/3q/36/16v/7sek+Q3q/36/16v96v9+v9er/+7HpPkN6v9+v9er/er/fr/Xq//uz6myVIIvK/FpH/VkT+nYj8X/5Wv+f9er/er/fr/Xq/3q/36z/1JX8LHyQRsQD+OwD/KwB/BPBfA/jflVL+m//kv+z9er/er/fr/Xq/3q/36z/x9bdCkP4XAP5dKeXfl1JWAP93AP+bv9Hver/er/fr/Xq/3q/36/36T3q5v9Hn/h2Af3rz738E8L/8izcx7ku/fwQASALMmiE5o1iDYgRSCiCCbAXZA8UCKEAx/Ae2AFkAAUwATNTPiQXIRMiKEcAARfhzAFAEgACSgWz1c8HPlswfy5Y/AwXaJPOzpQCSCqTo5xT+XLH1wwt/DgBShmR+h6L/GXofxQqS1z/UzzEJkFz0Z/n9i6l/UT/e/Pa+iug95e3P6+8yqUBSQRHh/Znf/j3U55j59+tz5Xf8s5dl3vw9fR7134t78+dlu6c/f9b1/t/+rOTtv9XnX++j6O80gb8/W6D4sn3/JPrMtt9Tv4tJ9fMKshNkt/2eeo/FbWtKkr5H4TpCefNcy/bf2n3VZ6m3s/3eAgkZkjJgTHv22QuKe/N5dU1lvuN/6bmglN/8TL2y48+b9NtnWn8/CtdX+476jm3gZ5o1Q2IGdJ9B9xmfn65tq2um1DUpSD2ft4mlPYvspb2n+oxg9JHV51O/T9a9Cb4P1Getfy/rvdY1Uf+xcwJyBvSZwshv9lTdYyilfZfsLd9t/V5Fn2d990ba/fEm+V3bvdc/1j0kNZ5YQe7M9pyM/Ga91+/Q1uKbtVHv1cS6L7e/V++b7022eGLactE4ps/LbM+q7h2TAInb9yl2i2O/+Vm9p7p+ROMO/7c+QwDFGWSNUSZu6yp13DeMtWi/p/gCcQXGMBilaCBBIAG/2ac13koCTCgwawKybkARwGzvj7Gcv1dK+c39Iev/NgbFGz4782Yd5y0u1+fQ9rnhfRSzrdF6tpgI2FXjpsFvzh4Tt5/97d4v24uqz1vXz9u1AiNtfckadU0LYG179zBoe6/dn/5Kk8q/GEshWwxpa+rNejTxTZwEf6aefb+Jc2/OkvZZ7YwDzFpg1gJYQXaC5Pnf6toz+n3596U+DiDpue7exLr2pf7sHKt/nNGeHTKYD5Rtr0jK2zPV59b+8ptzuX7PP3/fy5/++LWU8hn/wvW3SpDkX/iz3/TyROS/APBfAIC7+4D/0f/h/4TDPxUc/rTCThH2sqD0HmlwQCkozmD67HH5NxbrPeBP/JLXf0hAAezNwF8F7gp0p4LhOWP/TzeYJSKPHuHgkQarmwtIgyB1hgEawO07g7AH7MK7d1duyuwERoOYiYCbSwvYu19W+OcZ4WHA7YcOcRCYUNCfMuyUYOcEew0wS0Bx9ZQ3KN4i9RbhziMNBskL3JKRvLQDsH9N/N66seNoEEdB6vSQL5owgIsfBeguBd0pIe4YJEwssFPWBcfvsHxwvM9Y4KYCuxakThBHQdgzyHfnguwFyweBnQqG5wK3ZEgC1oPBeuSmaO9Sn6ldt802fx
QsHzOyB9xVEI88JXZ/tDArg03qgTQWZAe4CRi+8vkBfA81KMUdE+OayJrE75wGnnLdCfCX8psEx18LN2sqCHuD8WvEcm+RvcBNmd95EKRRsNzXv8PPD3cF/bNg+FaQegAZGF4y7FowP1ikAXrIAzYUmPDbA7cmgv6c0J0CihPMjx3snDF9cog7oH8pLYDalZ9RrCCMTKJMKOiuGcOvCyRm7gdnUHqP9b5DsYL13uH6vUH/UuCWgtQBYSeQzLXgr7kFiOv3FtN3gnBXMHwRDF8LhpeE4dcFpTMwS4KZI+KhA6xA1gwpBeHgEe4cJBf03wLsFDD9sIMJGW7myRoHi/Xe4fIH09ZmMUDcAbkr/N93GUUKDv/B4e4/ZPhzQhoM5geD4mphU9r9hj3XupsL/LWgO2WMP99gv7yivJ6A7xnPSu9QrIWZV8i0IN/tmMicrkCMSD98xO0f9rh9tEi9wK6MDe6W4aYEsybY64p06FGMIPVMvMLBwqwFds2QVGAW3ctrRHEG4XHE7fsO61EwPGdkK+19mgjEXrg2ahKphVAauMa6c8H+p4D+lwu/hxWsn/aAALfvPWLP/VXMFnOK5ftd7xkD7ALkbtsL/ZNg90tGd8laDDCZDXsgHJnQ2IX3lzru3e5UYGdgeM3IFki9oH/hs0kd48h6ZzQR4/soAsSR8SHsuG/8We/RAdMnQTwUrD8GHB+vCMFhfhngvjkc/igYv3JdJv2O/pYxfA3o//EJmGbAOcA7pI9HhGOH0z908LeC/iVqbMjwLzMkJCAmSIgoziLfjZi/Gxn/ZEv8WjFtBbfPhvH9xv2SO0Ecuea618L4D34/yQW7LxkmMEZe/o3FesfPGn8Fn9PCNZu9QFKBv/HZm7XA3yKS53rqvk0wcwTWgPDDPdJo4a4RZk0wrzfIGlB2A9ehNZCUsX7oEQ4W69FAEuCnjDjwvXaXjO6UkK20ZGM9GMSdJuoJcAvfiV3LBhjovfIzGSPPf2+w3uvPzkzI/AUYnhizs54lkhnv/CVj+PkKAJj+sMfts0P2jDnDU+TZd1shISEdB6z3fktQHZ9r6gS37yyWDzwD3EXgZr4Xs25Jkpu5B7MjQGJDgZ0z+qcF9vmGMnZI+w6pt3CXFbl3WD56SGTs9+cASQXrhw7TR4f5k+D8n3NhmElgJ8G/+7/+n//jX8hj/mYJ0h8B/P2bf/83AP709gdKKf8lgP8SAMYf/r70Tzww1nsH2xvI0cOuGWbNSL3F/Mnj9okB2J+4sYsB0q8WRbjx++eC/qQVUN4yaEkFdsmaYQLuGuEmQfIG2RvYNcPfLKZHy0qoogIC+CsTjOQ14QgMLtkJwsEh9XtMHy2WewMbCpYHg+mzwcP/F7AhA84AC4BaOSZWwcZKy9YlF8Sef3+645+byKCUem6I9U6wfACKLXBXQXcqkMgAmToGZ7sCcWc0I+eha1JBHA0kCdLwW/QBKPpsuImys1geBOXGhZlP/Nn5g8BNFn7SpKOiQpYJjBQmUjXhmj8YuBlYs6AcAqKzkEVgF4EkYPjGD1geBYtjxdE/CfqX0qrS7PmOUy9MUhwPlvqOUyeYPjFxcrcCd+NzdHP5LXpiALvwGQAMxskzGQwHHlj1Wh6A3LNSzQ6Ig64FC1z+YPlM9DD3U0HYsdJ1EzctBAgHi+Woz7kYuNkga9VrUsH4FLEk25ATAEwSDIM1MpM7GwA7F9jrCpkDkBIkWZTes+ozrNrcrSB7IEfeV3aCNPJzTajJNZ9T3PNwsxOfg2Qg7RxMZBKAmGGngDR6Vu9ClKQYMGjFDHOaMIpAloB0HJA7g7jndzaRSZFZmfC2QmQokFVQ+oK4Y/Lgr7yvuOP7tSt/3s5Ad22QDg/RDPSvihJYCzkcuD1DhITYKkkYA4kZpXPAYYTcZkjiQb98ZAU5PLG67Z8X2NeJ6G6IMKcb0qc7xP2gz45FghSBjRtytAUwrivsBWE08FNuyTkK4BMT1CLckxX1sIH7200MCP
Fu4H8LiYe4MzxEynawVZTIBMDNLCzKDggHIO0K0lBgZ4FdAH/j3wn3LPjiyPfiLoDPBd2lIFvB9B3XfrHC+HeKWO9
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Projected final density: sum the painted field along its last axis\n",
"figure(figsize=(10, 10))\n",
"imshow(res.sum(-1))"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "7fba6185-ce16-471d-8043-c0d65d4cdaf4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x1535400f9700>"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkgAAAJCCAYAAAA7hTjJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9ya+u25bmB/1m9RZfvcpdn+Kec2/EzYiMyCSDtFO2hUCAgAYgQQPo0XGLBk33+D/cQIIOSIgOki3cQbbstOUswpk342bELU6167VX8dVvNSsa411r3whlBCDlFTelPaWjffY5a33FfGcxxjOe5xkq58yn8Wl8Gp/Gp/FpfBqfxqfxcej/f3+AT+PT+DQ+jU/j0/g0Po3ftfEpQPo0Po1P49P4ND6NT+PT+CvjU4D0aXwan8an8Wl8Gp/Gp/FXxqcA6dP4ND6NT+PT+DQ+jU/jr4xPAdKn8Wl8Gp/Gp/FpfBqfxl8ZnwKkT+PT+DQ+jU/j0/g0Po2/Mn5rAZJS6n+klPqFUurXSqn/4Lf1Pp/Gp/FpfBqfxqfxaXwa/7qH+m34ICmlDPBL4H8AvAb+MfC/zjn/y3/tb/ZpfBqfxqfxaXwan8an8a95/LYQpL8P/Drn/G3OeQD+r8D/7Lf0Xp/Gp/FpfBqfxqfxaXwa/1qH/S297jPg1W/8/TXwb/21H2I5yebsBJJCDQqVIRvINoMCVEbpjAJyVmSvIQM6y+9EMD1kC8kCGXmNMfxTUf6bvNb42kr+e3YZFRXZgLKJHDRE+dFcJJSCnIEwvieg0v3rZPQg76/S+N58fC8dQEX5payVvKh8HcgQHaRi/Fk//qqS76H9+Dk1qAC5yCiTyVFBHOdIyWtpDyRITr6PKSLRG8hgWnnfrHn4/FmPf9c8zC+M8xnHzxkV2IxzgcJE+mgJ3oAfv4SRedODfM5UZNBQFAGjEu3gMCaRkiYnBQm0S1TWU5lAGx1tX8h86EwOCuUVufiNz+K1/GkyWo/z2BqyzSibUSqTs3weoxO19VidaIMjZkXOitIGjn0JCXl9BRiZLx0+rsHkQBfx4X0UmRANdBoVIFXyO+gsc5o/ThuMc2DkT5nX8QHlj/OOzehOfZx/l1E6yfxkBUl9/O4qo8a/5oys86DI5n4hy9rL5v7988P3kr0x/vv98zTyIbRNpKRkPSP7xm0D5ESqHaEa90L6uFbu98/9fH18T/n9WHzcU2jAJkoXyCgGb8BrVBr3iBvnxmQKFwlRY3QiZUVqrez5cV/ffwcV1MPz0gFUyqgEw0w9vG42455RoCPokMkaYqGIxW/Mh/p4NtzvORVkH+ZyfFCRj8/iNx4l92sjqoe50UUkRY0an4UKQD2eIyqPKWiGQWOG8bXu10sp80BUco6EcV2ocZ5cgqg/foik/tJzwY5/CTK/3J95evys4/9WXpGdrAE1LtqckS8VFNrL2aWSzF1y8j1UHM/UYvz7/dpSH9eAGtdEmsibORtlCpPGmUihI21whGBQvfq4P/T9YcTHP6NCAbp/WOLyvO4/0/0Zq37jd8af+3gmy/+yzfgoS9mXtpOfS8V4/t3vFYXsS5PkbokKNc5TGs9QAOUSqEzuzcO8YDIE9Zfm5f5zZSP/X9tE7gwkWV+6VX/p3NE+o2Ima0Ua1+r9d8BlrJE9gteydsY7UZsk667/uE7lg473wvh4sbIO9KDkGefxrNCyLigTWstEpKTkLPqNs0epTIpa3i/oh/tWpY/3a7Z8PAyDknkxcn7rQZ53qvLHNTzOr1ZZ7pSkHu7z+7v446b7jbUSFLb9+N7J/cb5NI7kfmN9je8Hcu7dr/0YNcYkpm6gj5b9Lz/c5Jwv+FeM31aApP4V/+0v1fKUUv8+8O8DmPMlj/+D/z3oTPHBPhz8yWX8WeDkyY5F1XN7nHDY1LirArdTNM8DahLJjWH6g8UdoT+RCauvZCH5qSJMYVhlsgJ3UCSbiZNMqjKrFxs2P6xQq4
F0dKjJ/Y5X2Mo/XJj9ocRUgRQ09nUph+8sUb81ZAXFDspNpjuT1el2mflrT6w05cYTS8PhiSNZmNzKE929sOiQaR8p+vNIeWPoLgOP/kvN+vcV/kVPDhrVGcxJT2gt9tphu/E7VHIJuL1GewgTmS+yorgxuIOivsqEiXymYS4LKJUZFcHPMmkVUHuLOuvROhM6i7IJfV0QZxG3GIhB8/hiS8qKXVPRrGvsncM2imGVcE+P9McCvXYUL47kDCkpwtWEYqtQPz3Az+dkBf1nPX/w5VsmduBnb5/iB4t1kWFbYjeWWKWHQ69+Z2heBPTC83vPrvju5oxuW2JvnFwiRSZNInoSWK2OfL5c04SCN9slhw9Tzp9t+fHJNT/sT3j78gzdGKoXe9Q/Xj6sxvZRIp14TBkpS08IBn9Tk11C1wFuSrLN5GmUi+xgcDuZ71RkbKOIJYRpRgVw+/HCcRBmmeGJvDYqU5ae/rs51bXGdNCfQv8oyKGgMmoMJNGQJ5Fy1jM0BerOyQWUYPLa4meZsEjMf23QAfwMunOZN9NJEBhmSYLNBHrhqScDXVuQvJYA+rYkm8zkleXxP+op3+1Rx5b+ywve/YOK7jKhBzW+JlQ3CtNlYqnQITN9n7Btorxu+P5/upQ9d51pHivC1y3nJ3uGYGj7gu6uYv5Lh20yu68y6dwzWzV8cbKmMh6tMk0o+Pb2jOP7KXruSb1BtXJTZJsxB8PslXyGYaVwe1nD7phpHmnKdWb6PrD+icO2mfoukRU0jwzDXE6kWGVZ87OI2cvBvPwVJKvY/7sNZelpbie4G0uxG+dxIgvFr5IE5klx+eUtHz4s0beOYq0pN2D6zN0fZnKRWDzZ03aOFA1pXaAGhW0UxU5hOvBT8MuM235MdNrnkYv/WtNcKrqLTHWjOHwZZE2Ml5pZespqoO8KYm+o5j0pKfz7Ce4wBmgRwucdWmei15j3JbZV+FkiXQxM5j3NviS3FmzCrB22lbPAtopYZuIs4TYGPYznhRuDKgNup6luwLaZbCBUCr+A7nFk+XzLsu7YdSXLuuP2OGFe9XTest1Nydcl1QdNsjCsEmkaxyBD49aG+oNiWEKoMrOXCtvK2agSFBt53s0T9ZA43ieRyWbKtWL6LtGdaPZfJcobCSi684TpJQCdvM/0JzKfyYJroL3I6K8P1KVnczXHbixhHsEl3I3DHhXuALs/6tFbh2kUYTrOxyxIgLx3FGuD6RRhInvTbjVur2ifR6p3hvZZQM89+nWF2ysWPySKQ6I5M3QXY6Kp4fhZxF62hJsau9PoAMNpwjSasJDPNVm1tO9nFLca3Sv6i0R5q7FHCBPonkRykSjfOlQCt4dYyTOOdaa61jRfeNyNJdYZLnrKyqNUpu8dClgtGkobuDtMaPcl7qogPu/IQVN+Lwmn7cYA1IJfJnSvcAcJ9LuLRD4fyI2lem/RPRRbSCV0p5niD7fMqp6cFdNi4NtXFzBoVK9h5SVpToqi8hRFIEZN3xWoVxXFVsKL6btMe67oTzNhlnBbTTayz1WGWMuZ+JBY2oyuA88uN7y7XRL2jtPTjv3/8v/ww18XyPy2AqTXwIvf+Ptz4O1v/kDO+T8E/kOA6tmLXH4wDKuEPcoi8ys5oFWv2e4mdIP7S9HmsMxQJpRJ5FJRrjPNE0V/EdGdIjm5hIZlJkzl4EJn4lnGlJH8oSRXkUNTStSdFXZnSDNPGgyTZUtKGmsjhw9TSAo3jzy+WFM8i1wfp6zfL2g+y5QfDP2Zwgxw8itPvzD4iQKlmLw8kipLri06ZmKlaM6MZL4ZhoUEMbHShDpTv7MMs0y2kJPCTQe8KgjbQg75WQItqIYOMhexzsQairUia0ucJ6ZvoNglQq0YFnKBpwJIUG4kCwg1cvhWEZUhHB16b8inHhUVqoroX00wQfG2tVQnHTFolEuoALHM5JOBogg4FzlERc5wvjjSDI71pMRHg/52Rp4nUpUxhWSaTShwLlIUgUnhuY
2avDOUd4ZhkUjziF9qKCUbv2mmDL2lel2gAvhFZv6tJhWa9sKyAS6mRxpf0BxKVK+5ebvk0JZ4b8Al8koQre6LgDl
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Projected initial density field, for comparison with the evolved one\n",
"figure(figsize=(10, 10))\n",
"imshow((0+initial_conditions).real.sum(-1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b03d0661-e0ba-48d0-a225-1ae71ad5be15",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Jax",
"language": "python",
"name": "jax"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}