{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Multi-GPU Particle Mesh Simulation with Advanced Solvers**\n",
"\n",
"<a href=\"https://colab.research.google.com/github/DifferentiableUniverseInitiative/JaxPM/blob/main/notebooks/04-MultiGPU_PM_Solvers.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# Make Equinox/diffrax runtime errors return NaNs instead of raising (set before importing diffrax)\n",
"os.environ[\"EQX_ON_ERROR\"] = \"nan\"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax_cosmo as jc\n",
"\n",
"from jaxpm.kernels import interpolate_power_spectrum\n",
"from jaxpm.painting import cic_paint_dx\n",
"from jaxpm.pm import linear_field, lpt, make_diffrax_ode\n",
"from functools import partial\n",
"from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,\n",
"                     PIDController, SaveAt, diffeqsolve)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> **Note**: This notebook requires 8 devices (GPU or TPU).\\\n",
"> If you're running on CPU or don't have access to 8 devices,\\\n",
"> you can simulate multiple devices by adding the following code at the start **BEFORE IMPORTING JAX**:\n",
"\n",
"```python\n",
"import os\n",
"os.environ[\"JAX_PLATFORM_NAME\"] = \"cpu\"\n",
"os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n",
"```\n",
"\n",
"**Recommended only for debugging**. If you use it, you will probably need to lower the mesh resolution."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"assert jax.device_count() >= 8, \"This notebook requires a TPU or GPU runtime with 8 devices\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setting Up Device Mesh and Sharding for Multi-GPU Simulation\n",
"\n",
"The next cell configures a **2x4 device mesh** across 8 devices and sets up named sharding to distribute the simulation arrays.\n",
"\n",
"- **Device Mesh**: `pdims = (2, 4)` arranges the devices in a 2x4 grid, and `create_device_mesh(pdims)` builds this layout from the available GPUs.\n",
"- **Sharding with Mesh**: `Mesh(devices, axis_names=('x', 'y'))` names the two mesh axes `'x'` and `'y'`, so that array dimensions can be mapped onto them by name.\n",
"- **PartitionSpec and NamedSharding**: `P('x', 'y')` splits the first array dimension across the `'x'` mesh axis and the second across the `'y'` axis, leaving the third dimension unsplit; `NamedSharding(mesh, P('x', 'y'))` applies this scheme to the arrays in the simulation.\n",
"- **Gathering**: `all_gather` wraps `process_allgather` to collect a sharded array into a single NumPy array on the host, which we use only for plotting.\n",
"\n",
"For more on sharding in general, see [Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from jax.experimental.mesh_utils import create_device_mesh\n",
"from jax.experimental.multihost_utils import process_allgather\n",
"from jax.sharding import Mesh, NamedSharding\n",
"from jax.sharding import PartitionSpec as P\n",
"\n",
"all_gather = partial(process_allgather, tiled=False)\n",
"\n",
"pdims = (2, 4)\n",
"devices = create_device_mesh(pdims)\n",
"mesh = Mesh(devices, axis_names=('x', 'y'))\n",
"sharding = NamedSharding(mesh, P('x', 'y'))"
]
},
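{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the effect of `P('x', 'y')` concrete, here is a minimal sketch (not part of JaxPM) that rebuilds an equivalent 2x4 mesh and asks the sharding for the block each device holds. It assumes at least 8 devices are visible, as asserted above, and uses the hypothetical names `demo_mesh` and `demo_sharding` so that it does not overwrite `mesh` and `sharding`.\n",
"\n",
"```python\n",
"import os\n",
"\n",
"# Only has an effect if JAX has not been imported yet (e.g. a standalone CPU run)\n",
"os.environ.setdefault('XLA_FLAGS', '--xla_force_host_platform_device_count=8')\n",
"\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.sharding import Mesh, NamedSharding, PartitionSpec as P\n",
"\n",
"# Arrange the first 8 devices in a 2x4 grid with named axes 'x' and 'y'\n",
"demo_mesh = Mesh(np.array(jax.devices()[:8]).reshape(2, 4), axis_names=('x', 'y'))\n",
"demo_sharding = NamedSharding(demo_mesh, P('x', 'y'))\n",
"\n",
"# Axis 0 is split 2 ways, axis 1 is split 4 ways, axis 2 stays whole\n",
"block = demo_sharding.shard_shape((1024, 1024, 1024))\n",
"print(block)  # (512, 256, 1024)\n",
"print(np.prod(block) * 4 / 2**20, 'MiB per float32 field per device')  # 512.0\n",
"\n",
"# Same layout on a small concrete array: every device holds an (8, 4, 16) block\n",
"x = jax.device_put(jnp.zeros((16, 16, 16)), demo_sharding)\n",
"print({shard.data.shape for shard in x.addressable_shards})\n",
"```\n",
"\n",
"Each device therefore stores and processes only one eighth of every mesh-sized array, which is what lets the large run below spread its memory and compute across the devices."
]
},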
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@partial(jax.jit, static_argnums=(2, 3, 4, 5))\n",
"def run_simulation_with_fields(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):\n",
"    mesh_shape = (mesh_shape,) * 3\n",
"    box_size = (box_size,) * 3\n",
"    # Tabulate the linear matter power spectrum and wrap it in an interpolating function\n",
"    k = jnp.logspace(-4, 1, 128)\n",
"    pk = jc.power.linear_matter_power(\n",
"        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
"    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
"\n",
"    # Create sharded Gaussian initial conditions\n",
"    initial_conditions = linear_field(mesh_shape,\n",
"                                      box_size,\n",
"                                      pk_fn,\n",
"                                      seed=jax.random.PRNGKey(0),\n",
"                                      sharding=sharding)\n",
"\n",
"    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
"\n",
"    # Initial displacement (dx) and momentum (p) from 2nd-order LPT at a = 0.1\n",
"    dx, p, f = lpt(cosmo,\n",
"                   initial_conditions,\n",
"                   a=0.1,\n",
"                   order=2,\n",
"                   halo_size=halo_size,\n",
"                   sharding=sharding)\n",
"\n",
"    # Evolve from a = 0.1 to a = 1 with a fixed-step Leapfrog solver\n",
"    ode_fn = ODETerm(\n",
"        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
"    solver = LeapfrogMidpoint()\n",
"\n",
"    stepsize_controller = ConstantStepSize()\n",
"    res = diffeqsolve(ode_fn,\n",
"                      solver,\n",
"                      t0=0.1,\n",
"                      t1=1.,\n",
"                      dt0=0.01,\n",
"                      y0=jnp.stack([dx, p], axis=0),\n",
"                      args=cosmo,\n",
"                      saveat=SaveAt(ts=snapshots),\n",
"                      stepsize_controller=stepsize_controller)\n",
"    # Paint the displacements saved at each snapshot (sol[0]) and the LPT displacement onto the mesh\n",
"    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]\n",
"    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)\n",
"    return initial_conditions, lpt_field, ode_fields, res.stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Large-Scale Simulation Across Multiple Devices\n",
"\n",
"In the next cell, we run a large simulation that would not be feasible on a single device. By distributing data and computation across the 8 devices, we can afford a high resolution (`mesh_shape = 1024`, i.e. a $1024^3$ mesh, with `box_size = 1000.`), using a `halo_size` of 128 cells for the halo exchange that handles particles near the boundaries between neighboring devices.\n",
"\n",
"We then gather the initial conditions and computed fields from all devices for visualization.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"45.6 s ± 11.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 2000x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from jaxpm.plotting import plot_fields_single_projection\n",
"\n",
"mesh_shape = 1024\n",
"box_size = 1000.\n",
"halo_size = 128\n",
"snapshots = (0.5, 1.0)\n",
"\n",
"# The first call triggers JIT compilation; %timeit then measures the compiled function only\n",
"initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)\n",
"ode_fields[-1].block_until_ready()\n",
"%timeit initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots); ode_fields[-1].block_until_ready()\n",
"\n",
"# Gather the sharded fields to the host for plotting\n",
"initial_conditions_g = all_gather(initial_conditions)\n",
"lpt_field_g = all_gather(lpt_field)\n",
"ode_fields_g = [all_gather(p) for p in ode_fields]\n",
"\n",
"fields = {\"Initial Conditions\": initial_conditions_g, \"LPT Field\": lpt_field_g}\n",
"for i, field in enumerate(ode_fields_g):\n",
"    fields[f\"field_{i}\"] = field\n",
"plot_fields_single_projection(fields, project_axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This simulation runs in about **45 seconds**, roughly half a second per step for the 90 fixed Leapfrog steps ($\\Delta a = 0.01$ from $a = 0.1$ to $a = 1$), with initial conditions, LPT and painting included in the timing. That is fast for a setup with **over one billion particles** ($1024^3 \\approx 1.07$ billion), and it demonstrates the efficiency of distributing data and computation across multiple devices.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Comparing ODE Solvers: Leapfrog vs. Dopri5\n",
"\n",
"Next, we compare the fixed-step **Leapfrog** solver with **Dopri5** (an adaptive Runge-Kutta method) to observe differences in accuracy and performance for particle evolution. To keep the comparison quick, we switch to a $512^3$ mesh and first time the Leapfrog run as a baseline.\n"
]
},
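{
"cell_type": "markdown",
"metadata": {},
"source": [
"The solver swap can first be tried on a toy problem. The sketch below is a hypothetical stand-in for the PM dynamics, not JaxPM code: it replaces `make_diffrax_ode` with a harmonic oscillator (state $[x, v]$, exact solution $x = \\cos(t - 0.1)$, $v = -\\sin(t - 0.1)$) and reuses this notebook's time span, `dt0` and tolerances. It shows that only the `solver` and `stepsize_controller` arguments change, and how to read the step count from `stats`.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,\n",
"                     PIDController, SaveAt, diffeqsolve)\n",
"\n",
"# Toy vector field: harmonic oscillator with state y = [x, v]\n",
"def vector_field(t, y, args):\n",
"    x, v = y\n",
"    return jnp.stack([v, -x])\n",
"\n",
"term = ODETerm(vector_field)\n",
"common = dict(t0=0.1, t1=1.0, dt0=0.01, y0=jnp.array([1.0, 0.0]), saveat=SaveAt(t1=True))\n",
"\n",
"# Fixed-step leapfrog/midpoint, as in run_simulation_with_fields\n",
"leap = diffeqsolve(term, LeapfrogMidpoint(), stepsize_controller=ConstantStepSize(), **common)\n",
"# Adaptive Dormand-Prince 5(4) with a PID step-size controller, as in run_simulation_with_dopri\n",
"dopri = diffeqsolve(term, Dopri5(), stepsize_controller=PIDController(rtol=1e-5, atol=1e-5), **common)\n",
"\n",
"exact = jnp.array([jnp.cos(0.9), -jnp.sin(0.9)])  # exact state at t1 = 1.0\n",
"for name, sol in [('Leapfrog', leap), ('Dopri5', dopri)]:\n",
"    n_steps = int(sol.stats['num_steps'])\n",
"    err = float(jnp.max(jnp.abs(sol.ys[-1] - exact)))\n",
"    print(f'{name}: {n_steps} steps, max error {err:.1e}')\n",
"```\n",
"\n",
"On a smooth problem like this one, the adaptive controller typically meets its tolerance in far fewer steps than the 90 fixed Leapfrog steps; the timings below show how that trade-off plays out for the full simulation."
]
},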
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5.04 s ± 9.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"mesh_shape = 512\n",
"box_size = 512.\n",
"halo_size = 64\n",
"snapshots = (0.5, 1.0)\n",
"\n",
"# Leapfrog baseline; the new static arguments trigger a recompilation on the first call\n",
"initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)\n",
"ode_fields[-1].block_until_ready()\n",
"%timeit initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_fields(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots); ode_fields[-1].block_until_ready()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.44 s ± 8.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"Solver Stats : {'max_steps': Array(4096, dtype=int32, weak_type=True), 'num_accepted_steps': Array(12, dtype=int32, weak_type=True), 'num_rejected_steps': Array(0, dtype=int32, weak_type=True), 'num_steps': Array(12, dtype=int32, weak_type=True)}\n"
]
}
],
"source": [
"mesh_shape = 512\n",
"box_size = 512.\n",
"halo_size = 64\n",
"snapshots = (0.5, 1.0)\n",
"\n",
"@partial(jax.jit, static_argnums=(2, 3, 4, 5))\n",
"def run_simulation_with_dopri(omega_c, sigma8, mesh_shape, box_size, halo_size, snapshots):\n",
"    mesh_shape = (mesh_shape,) * 3\n",
"    box_size = (box_size,) * 3\n",
"    # Tabulate the linear matter power spectrum and wrap it in an interpolating function\n",
"    k = jnp.logspace(-4, 1, 128)\n",
"    pk = jc.power.linear_matter_power(\n",
"        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)\n",
"    pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)\n",
"\n",
"    # Create sharded Gaussian initial conditions\n",
"    initial_conditions = linear_field(mesh_shape,\n",
"                                      box_size,\n",
"                                      pk_fn,\n",
"                                      seed=jax.random.PRNGKey(0),\n",
"                                      sharding=sharding)\n",
"\n",
"    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)\n",
"\n",
"    # Initial displacement (dx) and momentum (p) from 2nd-order LPT at a = 0.1\n",
"    dx, p, f = lpt(cosmo,\n",
"                   initial_conditions,\n",
"                   a=0.1,\n",
"                   order=2,\n",
"                   halo_size=halo_size,\n",
"                   sharding=sharding)\n",
"\n",
"    # Evolve with the adaptive Dopri5 solver; only the solver and step-size controller differ from the Leapfrog version\n",
"    ode_fn = ODETerm(\n",
"        make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))\n",
"    solver = Dopri5()\n",
"\n",
"    stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)\n",
"    res = diffeqsolve(ode_fn,\n",
"                      solver,\n",
"                      t0=0.1,\n",
"                      t1=1.,\n",
"                      dt0=0.01,  # initial step size; the controller adapts it afterwards\n",
"                      y0=jnp.stack([dx, p], axis=0),\n",
"                      args=cosmo,\n",
"                      saveat=SaveAt(ts=snapshots),\n",
"                      stepsize_controller=stepsize_controller)\n",
"    # Paint the displacements saved at each snapshot (sol[0]) and the LPT displacement onto the mesh\n",
"    ode_fields = [cic_paint_dx(sol[0], halo_size=halo_size, sharding=sharding) for sol in res.ys]\n",
"    lpt_field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)\n",
"    return initial_conditions, lpt_field, ode_fields, res.stats\n",
"\n",
"initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_dopri(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots)\n",
"ode_fields[-1].block_until_ready()\n",
"%timeit initial_conditions, lpt_field, ode_fields, solver_stats = run_simulation_with_dopri(0.25, 0.8, mesh_shape, box_size, halo_size, snapshots); ode_fields[-1].block_until_ready()\n",
"\n",
"print(f\"Solver Stats : {solver_stats}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 2000x500 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Gather the sharded fields to the host for plotting\n",
"initial_conditions_g = all_gather(initial_conditions)\n",
"lpt_field_g = all_gather(lpt_field)\n",
"ode_fields_g = [all_gather(p) for p in ode_fields]\n",
"\n",
"fields = {\"Initial Conditions\": initial_conditions_g, \"LPT Field\": lpt_field_g}\n",
"for i, field in enumerate(ode_fields_g):\n",
"    fields[f\"field_{i}\"] = field\n",
"plot_fields_single_projection(fields, project_axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see how **easily we can switch solvers**: only the `solver` and `stepsize_controller` lines differ between the two functions. Although Dopri5 needed only 12 adaptive steps (versus 90 fixed Leapfrog steps), it ran only slightly faster in this case (4.44 s versus 5.04 s), because each Dopri5 step costs six force evaluations (seven stages, the last of which is reused as the first stage of the next step) while each Leapfrog step costs one.\n",
"\n",
"> **Note**: Dopri5 uses a **PIDController** for adaptive stepping, which might face challenges in distributed setups. In my experience, it works well without triggering all-gathers, but make sure to set:\n",
"> ```python\n",
"> os.environ[\"EQX_ON_ERROR\"] = \"nan\"\n",
"> ```\n",
"> before importing `diffrax` to handle any errors gracefully.\n",
"\n",
"However, **Dopri5 requires more memory** than Leapfrog, making a $1024^3$ mesh simulation infeasible on eight A100 GPUs with 80GB of memory each! For larger setups, we'll need more compute resources; this is covered in the final notebook, **05-MultiHost_PM.ipynb**.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "a100",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}